277 lines
12 KiB
Python
277 lines
12 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""RAG Inventory API — global knowledge-store visibility for users, admins, platform."""
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Request
|
|
from modules.auth import limiter, getCurrentUser, getRequestContext, RequestContext
|
|
from modules.datamodels.datamodelUam import User
|
|
from modules.shared.i18nRegistry import apiRouteContext
|
|
|
|
# Module logger and i18n message helper used by every route in this file.
logger = logging.getLogger(__name__)
routeApiMsg = apiRouteContext("routeRagInventory")

# All RAG-inventory endpoints live under /api/rag/inventory and share the same
# documented error responses.
router = APIRouter(
    prefix="/api/rag/inventory",
    tags=["RAG Inventory"],
    responses={
        statusCode: {"description": description}
        for statusCode, description in (
            (401, "Unauthorized"),
            (403, "Forbidden"),
            (500, "Internal server error"),
        )
    },
)
|
|
|
|
|
|
def _readField(record: Any, name: str, default: Any = None) -> Any:
    """Return ``name`` from a dict row or an attribute-style record, or ``default`` if absent."""
    if isinstance(record, dict):
        return record.get(name, default)
    return getattr(record, name, default)


def _buildConnectionInventory(
    connections,
    rootIf,
    knowledgeIf,
    jobService,
    jobLimit: int = 200,
) -> List[Dict[str, Any]]:
    """Build one inventory entry per connection: DataSources, chunk counts and bootstrap-job state.

    Args:
        connections: Iterable of connection records with attribute access to ``id`` and
            ``authority``; ``externalEmail``, ``knowledgeIngestionEnabled`` and
            ``knowledgePreferences`` are optional.
        rootIf: App DB interface; ``rootIf.db.getRecordset`` is queried for ``DataSource`` rows.
        knowledgeIf: Knowledge DB interface; ``knowledgeIf.db.getRecordset`` is queried for
            ``FileContentIndex`` rows.
        jobService: Object exposing ``listJobs(jobType=..., limit=...)`` returning job dicts.
        jobLimit: Number of recent ``connection.bootstrap`` jobs scanned for running/failed
            state. The list is fetched once and shared by all connections; jobs outside
            this window are not reported.

    Returns:
        A list of dicts, one per connection in input order, with keys ``id``, ``authority``,
        ``externalEmail``, ``knowledgeIngestionEnabled``, ``preferences``, ``dataSources``,
        ``totalChunks``, ``runningJobs`` and ``lastError``.
    """
    from modules.datamodels.datamodelDataSource import DataSource
    from modules.datamodels.datamodelKnowledge import FileContentIndex

    connections = list(connections or [])
    if not connections:
        return []

    # Fetch recent bootstrap jobs ONCE and bucket them by connection. Previously every
    # connection re-fetched only the 5 newest jobs (no connection filter), so a
    # connection's own running/failed jobs were usually missing on a busy system.
    jobsByConnection: Dict[str, List[Dict[str, Any]]] = {}
    for job in jobService.listJobs(jobType="connection.bootstrap", limit=jobLimit) or []:
        jobConnId = (job.get("payload") or {}).get("connectionId")
        if jobConnId:
            jobsByConnection.setdefault(str(jobConnId), []).append(job)

    out: List[Dict[str, Any]] = []
    for conn in connections:
        connectionId = str(conn.id)
        dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})

        # Each FileContentIndex row of this connection counts as one chunk.
        connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
        connChunkTotal = len(connIndexRows)

        # Attribute rows to the DataSource named in their provenance, when present.
        chunksByDs: Dict[str, int] = {}
        unassigned = 0
        for idx in connIndexRows:
            prov = _readField(idx, "provenance") or {}
            dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
            if dsIdRef:
                chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
            else:
                unassigned += 1

        dsItems: List[Dict[str, Any]] = []
        for ds in dataSources:
            dsId = _readField(ds, "id", "")
            dsItems.append({
                "id": dsId,
                "label": _readField(ds, "label", ""),
                "path": _readField(ds, "path", ""),
                "sourceType": _readField(ds, "sourceType", ""),
                "ragIndexEnabled": _readField(ds, "ragIndexEnabled", False),
                "neutralize": _readField(ds, "neutralize", False),
                "lastIndexed": _readField(ds, "lastIndexed", None),
                "chunkCount": chunksByDs.get(dsId, 0),
            })

        # With exactly one DataSource, rows lacking provenance can only belong to it.
        if unassigned > 0 and len(dsItems) == 1:
            dsItems[0]["chunkCount"] += unassigned

        connJobs = jobsByConnection.get(connectionId, [])
        runningJobs = [
            {"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")}
            for j in connJobs
            if j.get("status") in ("PENDING", "RUNNING")
        ]
        # First ERROR job in listJobs order.
        # NOTE(review): assumes listJobs returns newest first — confirm.
        lastError = next(
            (
                {"jobId": j["id"], "errorMessage": j.get("errorMessage", "")}
                for j in connJobs
                if j.get("status") == "ERROR"
            ),
            None,
        )

        authority = conn.authority
        out.append({
            "id": connectionId,
            "authority": authority.value if hasattr(authority, "value") else str(authority or ""),
            "externalEmail": getattr(conn, "externalEmail", ""),
            "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
            "preferences": getattr(conn, "knowledgePreferences", None) or {},
            "dataSources": dsItems,
            "totalChunks": connChunkTotal,
            "runningJobs": runningJobs,
            "lastError": lastError,
        })

    return out
|
|
|
|
|
|
@router.get("/me")
@limiter.limit("30/minute")
def _getInventoryMe(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Return the caller's personal RAG inventory.

    Covers the current user's own connections with their DataSources, per-DataSource
    chunk counts and bootstrap-job state, plus the summed chunk total.

    Raises:
        HTTPException: 500 on any failure (detail is the exception text).
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService

        rootInterface = getRootInterface()
        knowledgeInterface = getKnowledgeInterface(None)
        userConnections = rootInterface.getUserConnections(currentUser.id)

        connectionEntries = _buildConnectionInventory(
            userConnections, rootInterface, knowledgeInterface, jobService
        )

        chunkTotal = 0
        for entry in connectionEntries:
            chunkTotal += entry.get("totalChunks", 0)

        totals = {"chunks": chunkTotal}
        return {"connections": connectionEntries, "totals": totals}
    except Exception as exc:
        logger.error("Error in RAG inventory /me: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@router.get("/mandate")
@limiter.limit("20/minute")
def _getInventoryMandate(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Mandate-level RAG aggregation (requires mandate membership).

    Returns every UserConnection stored with the request's ``mandateId``, each with its
    DataSources, chunk counts and bootstrap-job state. Totals are ``chunks`` (summed over
    connections) and ``bytes`` (from ``aggregateMandateRagTotalBytes``).

    This handler only checks that a mandate context is present.
    NOTE(review): membership itself is presumably enforced by ``getRequestContext`` — confirm.

    Raises:
        HTTPException: 403 without a mandate context; 500 on any other failure.
    """
    if not context.mandateId:
        raise HTTPException(status_code=403, detail=routeApiMsg("Mandate context required"))
    try:
        from types import SimpleNamespace

        from modules.datamodels.datamodelUam import UserConnection
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)
        # context.mandateId is known to be truthy here (guard above).
        mandateId = str(context.mandateId)

        allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId})
        # Recordset rows may be plain dicts, but _buildConnectionInventory reads
        # attributes (conn.id, conn.authority, ...). SimpleNamespace exposes the keys as
        # attributes without building a throwaway class per row.
        connectionObjects = [SimpleNamespace(**row) if isinstance(row, dict) else row for row in allConnections]

        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        totalChunks = sum(c.get("totalChunks", 0) for c in items)
        totalBytes = aggregateMandateRagTotalBytes(mandateId)

        return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /mandate: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.get("/platform")
@limiter.limit("10/minute")
def _getInventoryPlatform(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Platform-wide RAG statistics (sysadmin only).

    Returns every UserConnection on the platform with its DataSources, chunk counts and
    bootstrap-job state, plus the summed chunk total.

    Raises:
        HTTPException: 403 unless ``context.isSysAdmin``; 500 on any other failure.
    """
    if not context.isSysAdmin:
        raise HTTPException(status_code=403, detail=routeApiMsg("Platform admin required"))
    try:
        from types import SimpleNamespace

        from modules.datamodels.datamodelUam import UserConnection
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)

        allConnections = rootIf.db.getRecordset(UserConnection)
        # Recordset rows may be plain dicts, but _buildConnectionInventory reads
        # attributes (conn.id, conn.authority, ...). SimpleNamespace exposes the keys as
        # attributes without building a throwaway class per row.
        connectionObjects = [SimpleNamespace(**row) if isinstance(row, dict) else row for row in allConnections]

        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        totalChunks = sum(c.get("totalChunks", 0) for c in items)

        return {"connections": items, "totals": {"chunks": totalChunks}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /platform: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# Strong references to enqueue tasks scheduled on an already-running event loop, so
# they are not garbage-collected before they finish (the event loop keeps only weak
# references to tasks).
_pendingEnqueueTasks: set = set()


def _onEnqueueDone(task) -> None:
    """Drop the reference to a finished enqueue task and log any exception it raised."""
    _pendingEnqueueTasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Deferred reindex enqueue failed: %s", exc, exc_info=exc)


@router.post("/reindex/{connectionId}")
@limiter.limit("10/minute")
def _reindexConnection(
    request: Request,
    connectionId: str,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Re-trigger bootstrap for a connection (re-index all ragIndexEnabled DataSources).

    Submits a new connection.bootstrap job, regardless of previous failures.

    Returns:
        ``{"status": "skipped", "reason": ...}`` when no DataSource of the connection has
        ``ragIndexEnabled``. Otherwise returns ``{"status": "queued", "connectionId",
        "dataSourceCount", "jobId"}``, where ``jobId`` is ``None`` if the job had to be
        scheduled on an event loop already running in this thread.

    Raises:
        HTTPException: 404 if the connection does not exist, 403 if it belongs to
            another user, 500 on any other failure.
    """
    try:
        import asyncio

        from modules.datamodels.datamodelDataSource import DataSource
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import startJob

        rootIf = getRootInterface()
        conn = rootIf.getUserConnectionById(connectionId)
        if conn is None:
            raise HTTPException(status_code=404, detail="Connection not found")

        # A record without userId cannot be proven to be the caller's: reject with 403
        # instead of failing with AttributeError (which surfaced as a 500).
        if str(getattr(conn, "userId", None)) != str(currentUser.id):
            raise HTTPException(status_code=403, detail="Not your connection")

        dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
        ragDs = [ds for ds in dataSources if (ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False))]
        if not ragDs:
            return {"status": "skipped", "reason": "no_rag_enabled_datasources"}

        authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
        dsIds = [(ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")) for ds in ragDs]

        async def _enqueue():
            return await startJob(
                "connection.bootstrap",
                {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds},
                triggeredBy=str(currentUser.id),
            )

        # get_running_loop() (unlike the deprecated get_event_loop()) never creates a loop.
        try:
            runningLoop = asyncio.get_running_loop()
        except RuntimeError:
            runningLoop = None

        if runningLoop is None:
            # Usual case: a sync route runs in a worker thread without an event loop, so
            # drive the coroutine to completion on a fresh, properly closed loop.
            # NOTE(review): asyncio.run() cancels tasks still pending when it returns; if
            # startJob spawns asyncio tasks that must outlive this call, confirm it has
            # persisted/dispatched the job before returning.
            jobId = asyncio.run(_enqueue())
        else:
            # A loop is already running in this thread, so we cannot block on it:
            # schedule the job, keep a strong reference, and log failures when done.
            task = runningLoop.create_task(_enqueue())
            _pendingEnqueueTasks.add(task)
            task.add_done_callback(_onEnqueueDone)
            jobId = None

        logger.info("Reindex triggered for connection %s (%d DataSources)", connectionId, len(dsIds))
        return {"status": "queued", "connectionId": connectionId, "dataSourceCount": len(dsIds), "jobId": jobId}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error triggering reindex: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.get("/jobs")
@limiter.limit("60/minute")
def _getActiveJobs(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> List[Dict[str, Any]]:
    """List the caller's active RAG jobs (feeds the header badge).

    Scans up to 50 ``connection.bootstrap`` jobs returned by ``listJobs``. A job is kept
    when it is PENDING or RUNNING and its payload ``connectionId`` is one of the caller's
    connections.
    NOTE(review): ``listJobs`` is called without a user filter here, so on a busy system
    the caller's jobs may fall outside the 50-job window — confirm its scope.

    Raises:
        HTTPException: 500 on any failure (detail is the exception text).
    """
    try:
        from modules.serviceCenter.services.serviceBackgroundJobs import listJobs
        from modules.interfaces.interfaceDbApp import getRootInterface

        rootInterface = getRootInterface()
        ownConnections = {str(c.id): c for c in rootInterface.getUserConnections(currentUser.id)}

        recentJobs = listJobs(jobType="connection.bootstrap", limit=50)
        activeStatuses = ("PENDING", "RUNNING")

        # Pair every active job with the connection id recorded in its payload.
        activeWithConnId = (
            (job, (job.get("payload") or {}).get("connectionId"))
            for job in recentJobs
            if job.get("status") in activeStatuses
        )

        return [
            {
                "jobId": job["id"],
                "connectionId": connId,
                "connectionLabel": (
                    getattr(ownConnections[connId], "displayLabel", None)
                    or getattr(ownConnections[connId], "authority", connId)
                ),
                "jobType": job.get("jobType", "connection.bootstrap"),
                "progress": job.get("progress", 0),
                "progressMessage": job.get("progressMessage", ""),
            }
            for job, connId in activeWithConnId
            if connId in ownConnections
        ]
    except Exception as exc:
        logger.error("Error in RAG inventory /jobs: %s", exc, exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|