# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""RAG Inventory API — global knowledge-store visibility for users, admins, platform."""

import asyncio
import logging
from types import SimpleNamespace
from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Tuple

from fastapi import APIRouter, HTTPException, Depends, Request

from modules.auth import limiter, getCurrentUser, getRequestContext, RequestContext
from modules.datamodels.datamodelUam import User
from modules.shared.i18nRegistry import apiRouteContext

routeApiMsg = apiRouteContext("routeRagInventory")

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/api/rag/inventory",
    tags=["RAG Inventory"],
    responses={
        401: {"description": "Unauthorized"},
        403: {"description": "Forbidden"},
        500: {"description": "Internal server error"},
    },
)

# Background job type used to (re-)index a connection's DataSources.
_BOOTSTRAP_JOB_TYPE = "connection.bootstrap"

# Job statuses treated as "still in progress".
_ACTIVE_JOB_STATUSES = ("PENDING", "RUNNING")

# Number of recent bootstrap jobs scanned when attributing jobs to connections.
# The job list is not filtered per connection, so the window must be wide
# enough that one busy connection does not push out every other connection's
# jobs. Matches the window used by the /jobs badge endpoint.
_BOOTSTRAP_JOB_SCAN_LIMIT = 50

# Strong references to enqueue tasks scheduled on an already-running event loop.
# asyncio keeps only weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_pendingEnqueueTasks: Set["asyncio.Task[Any]"] = set()


def _field(record: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from a DB record that may be a dict or an attribute object.

    Dict records use ``dict.get`` (a missing key yields ``None``); any other
    record uses ``getattr(record, name, default)``.
    """
    if isinstance(record, dict):
        return record.get(name)
    return getattr(record, name, default)


def _authorityValue(authority: Any) -> Any:
    """Return ``authority.value`` for enum-like authorities, else ``str(authority)``."""
    return authority.value if hasattr(authority, "value") else str(authority)


def _toConnectionObject(row: Any) -> Any:
    """Give dict rows attribute access so they can be used like ``UserConnection`` objects."""
    return SimpleNamespace(**row) if isinstance(row, dict) else row


def _sumChunks(items: List[Dict[str, Any]]) -> int:
    """Sum the ``totalChunks`` values of inventory entries."""
    return sum(item.get("totalChunks", 0) for item in items)


def _countIndexRowsByDataSource(indexRows: Any) -> Tuple[Dict[str, int], int]:
    """Count index rows per ``provenance["dataSourceId"]``.

    Returns:
        ``(countsByDataSourceId, unassignedCount)``, where rows without a
        dict provenance or without a truthy ``dataSourceId`` are unassigned.
    """
    counts: Dict[str, int] = {}
    unassigned = 0
    for row in indexRows:
        prov = _field(row, "provenance") or {}
        dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
        if dsIdRef:
            counts[dsIdRef] = counts.get(dsIdRef, 0) + 1
        else:
            unassigned += 1
    return counts, unassigned


def _dataSourceItem(ds: Any, chunksByDs: Dict[str, int]) -> Dict[str, Any]:
    """Serialize one DataSource record together with its index-row count."""
    dsId = _field(ds, "id", "")
    return {
        "id": dsId,
        "label": _field(ds, "label", ""),
        "path": _field(ds, "path", ""),
        "sourceType": _field(ds, "sourceType", ""),
        "ragIndexEnabled": _field(ds, "ragIndexEnabled", False),
        "neutralize": _field(ds, "neutralize", False),
        "lastIndexed": _field(ds, "lastIndexed", None),
        "chunkCount": chunksByDs.get(dsId, 0),
    }


def _summarizeJobs(connJobs: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """Split a connection's jobs into in-progress entries and the first ERROR entry.

    "First" follows the order returned by ``listJobs``; presumably newest
    first, so it is the latest error — NOTE(review): confirm the ordering.
    """
    runningJobs = [
        {"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")}
        for j in connJobs
        if j.get("status") in _ACTIVE_JOB_STATUSES
    ]
    lastError = next(
        (
            {"jobId": j["id"], "errorMessage": j.get("errorMessage", "")}
            for j in connJobs
            if j.get("status") == "ERROR"
        ),
        None,
    )
    return runningJobs, lastError


def _buildConnectionInventory(
    connections,
    rootIf,
    knowledgeIf,
    jobService,
    jobScanLimit: int = _BOOTSTRAP_JOB_SCAN_LIMIT,
) -> List[Dict[str, Any]]:
    """Build one inventory entry per connection.

    Args:
        connections: Connection objects exposing ``id`` and ``authority``
            (optionally ``externalEmail``, ``knowledgeIngestionEnabled``,
            ``knowledgePreferences``).
        rootIf: App DB interface; ``rootIf.db.getRecordset`` loads DataSources.
        knowledgeIf: Knowledge DB interface; ``knowledgeIf.db.getRecordset``
            loads ``FileContentIndex`` rows.
        jobService: Object exposing ``listJobs(jobType=..., limit=...)``.
        jobScanLimit: How many recent bootstrap jobs are scanned for running
            jobs and errors (shared across all connections).

    Returns:
        A list of dicts with connection metadata, per-DataSource counts,
        ``totalChunks``, ``runningJobs`` and ``lastError``.
    """
    from modules.datamodels.datamodelDataSource import DataSource
    from modules.datamodels.datamodelKnowledge import FileContentIndex

    connections = list(connections)
    # The job list is global (not per connection): fetch it once and attribute
    # jobs to connections through their payload's connectionId.
    bootstrapJobs = (
        jobService.listJobs(jobType=_BOOTSTRAP_JOB_TYPE, limit=jobScanLimit) if connections else []
    )

    out: List[Dict[str, Any]] = []
    for conn in connections:
        connectionId = str(conn.id)
        dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
        indexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})

        # NOTE(review): these are FileContentIndex row counts reported as
        # "chunks"; confirm one row corresponds to one chunk.
        chunksByDs, unassigned = _countIndexRowsByDataSource(indexRows)
        dsItems = [_dataSourceItem(ds, chunksByDs) for ds in dataSources]
        # Rows without a dataSourceId can only be attributed unambiguously
        # when the connection has exactly one DataSource.
        if unassigned > 0 and len(dsItems) == 1:
            dsItems[0]["chunkCount"] += unassigned

        connJobs = [j for j in bootstrapJobs if (j.get("payload") or {}).get("connectionId") == connectionId]
        runningJobs, lastError = _summarizeJobs(connJobs)

        out.append({
            "id": connectionId,
            "authority": _authorityValue(conn.authority),
            "externalEmail": getattr(conn, "externalEmail", ""),
            "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
            "preferences": getattr(conn, "knowledgePreferences", None) or {},
            "dataSources": dsItems,
            "totalChunks": len(indexRows),
            "runningJobs": runningJobs,
            "lastError": lastError,
        })
    return out


def _forgetEnqueueTask(task: "asyncio.Task[Any]") -> None:
    """Done-callback: drop the strong reference and log a failed enqueue."""
    _pendingEnqueueTasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background enqueue failed: %s", exc, exc_info=exc)


def _runCoroutineFromSync(makeCoro: Callable[[], Coroutine[Any, Any, Any]]) -> Any:
    """Run a coroutine from synchronous code and return its result.

    If no event loop is running in the current thread, the coroutine runs to
    completion on a fresh loop via ``asyncio.run``. If a loop is already
    running here, blocking on it is impossible: the coroutine is scheduled as
    a task on that loop (kept referenced until done) and ``None`` is returned.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(makeCoro())
    task = loop.create_task(makeCoro())
    _pendingEnqueueTasks.add(task)
    task.add_done_callback(_forgetEnqueueTask)
    return None


@router.get("/me")
@limiter.limit("30/minute")
def _getInventoryMe(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Personal RAG inventory: own connections + DataSources + chunk counts."""
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)
        connections = rootIf.getUserConnections(currentUser.id)
        items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
        return {"connections": items, "totals": {"chunks": _sumChunks(items)}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/mandate")
@limiter.limit("20/minute")
def _getInventoryMandate(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Mandate-level RAG aggregation (requires mandate membership)."""
    if not context.mandateId:
        raise HTTPException(status_code=403, detail=routeApiMsg("Mandate context required"))
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
        from modules.datamodels.datamodelUam import UserConnection

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)
        # mandateId is guaranteed truthy by the guard above.
        mandateId = str(context.mandateId)

        # NOTE(review): this lists every connection in the mandate, including
        # other members' externalEmail/preferences; the only check here is that
        # a mandate context exists — confirm that is the intended access level.
        allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId})
        connectionObjects = [_toConnectionObject(row) for row in allConnections]
        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        totalBytes = aggregateMandateRagTotalBytes(mandateId)
        return {"connections": items, "totals": {"chunks": _sumChunks(items), "bytes": totalBytes}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /mandate: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/platform")
@limiter.limit("10/minute")
def _getInventoryPlatform(
    request: Request,
    context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]:
    """Platform-wide RAG statistics (sysadmin only)."""
    if not context.isSysAdmin:
        raise HTTPException(status_code=403, detail=routeApiMsg("Platform admin required"))
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
        from modules.datamodels.datamodelUam import UserConnection

        rootIf = getRootInterface()
        knowledgeIf = getKnowledgeInterface(None)
        allConnections = rootIf.db.getRecordset(UserConnection)
        connectionObjects = [_toConnectionObject(row) for row in allConnections]
        items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
        return {"connections": items, "totals": {"chunks": _sumChunks(items)}}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /platform: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/reindex/{connectionId}")
@limiter.limit("10/minute")
def _reindexConnection(
    request: Request,
    connectionId: str,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Re-trigger bootstrap for a connection (re-index all ragIndexEnabled DataSources).

    Submits a new connection.bootstrap job, regardless of previous failures.
    Only the connection's owner may trigger it. ``jobId`` in the response is
    ``None`` when the job could only be scheduled (not awaited) because an
    event loop was already running in the calling thread.
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        from modules.serviceCenter.services.serviceBackgroundJobs import startJob
        from modules.datamodels.datamodelDataSource import DataSource

        rootIf = getRootInterface()
        conn = rootIf.getUserConnectionById(connectionId)
        if conn is None:
            raise HTTPException(status_code=404, detail=routeApiMsg("Connection not found"))
        if str(conn.userId) != str(currentUser.id):
            raise HTTPException(status_code=403, detail=routeApiMsg("Not your connection"))

        dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
        ragDs = [ds for ds in dataSources if _field(ds, "ragIndexEnabled", False)]
        if not ragDs:
            return {"status": "skipped", "reason": "no_rag_enabled_datasources"}

        authority = _authorityValue(conn.authority or "")
        dsIds = [_field(ds, "id", "") for ds in ragDs]

        async def _enqueue():
            return await startJob(
                _BOOTSTRAP_JOB_TYPE,
                {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds},
                triggeredBy=str(currentUser.id),
            )

        jobId = _runCoroutineFromSync(_enqueue)

        logger.info("Reindex triggered for connection %s (%d DataSources)", connectionId, len(dsIds))
        return {"status": "queued", "connectionId": connectionId, "dataSourceCount": len(dsIds), "jobId": jobId}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error triggering reindex: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/jobs")
@limiter.limit("60/minute")
def _getActiveJobs(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> List[Dict[str, Any]]:
    """Active RAG jobs for the current user (used by header badge)."""
    try:
        from modules.serviceCenter.services.serviceBackgroundJobs import listJobs
        from modules.interfaces.interfaceDbApp import getRootInterface

        rootIf = getRootInterface()
        connectionMap = {str(c.id): c for c in rootIf.getUserConnections(currentUser.id)}

        active: List[Dict[str, Any]] = []
        for j in listJobs(jobType=_BOOTSTRAP_JOB_TYPE, limit=_BOOTSTRAP_JOB_SCAN_LIMIT):
            if j.get("status") not in _ACTIVE_JOB_STATUSES:
                continue
            connId = (j.get("payload") or {}).get("connectionId")
            # Only report jobs belonging to the caller's own connections.
            if connId not in connectionMap:
                continue
            conn = connectionMap[connId]
            active.append({
                "jobId": j["id"],
                "connectionId": connId,
                "connectionLabel": getattr(conn, "displayLabel", None) or getattr(conn, "authority", connId),
                "jobType": j.get("jobType", _BOOTSTRAP_JOB_TYPE),
                "progress": j.get("progress", 0),
                "progressMessage": j.get("progressMessage", ""),
            })
        return active
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error in RAG inventory /jobs: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e