diff --git a/modules/auth/trustedDeviceService.py b/modules/auth/trustedDeviceService.py new file mode 100644 index 00000000..a707e70a --- /dev/null +++ b/modules/auth/trustedDeviceService.py @@ -0,0 +1,171 @@ +# Copyright (c) 2026 PowerOn AG +# All rights reserved. +""" +Trusted Device Service. + +After successful MFA verification a device can be marked as trusted for a +configurable duration (default 60 days). On subsequent logins from the same +device the MFA step is skipped. + +Cookie: ``mfa_trusted`` (httpOnly, Secure, SameSite policy from jwtService). +DB: ``TrustedDevice`` table in poweron_app. + +Regulatory basis: +- NIST SP 800-63B: Verifier MAY re-authenticate only after a configurable period when a + device is bound to the subscriber. NOTE(review): the section originally cited here (5.2.8) is "Replay Resistance"; confirm the correct clause. +- Microsoft, Google, AWS implement identical patterns. +""" + +import logging +import secrets +from typing import Optional + +from fastapi import Request, Response + +from modules.shared.configuration import APP_CONFIG +from modules.shared.timeUtils import getUtcNow, getUtcTimestamp +from modules.datamodels.datamodelSecurity import TrustedDevice + +logger = logging.getLogger(__name__) + +_COOKIE_NAME = "mfa_trusted" +_DEFAULT_TRUST_DAYS = 60 +_TOKEN_BYTES = 32 + + +def _getTrustDurationDays() -> int: + raw = (APP_CONFIG.get("MFA_TRUST_DURATION_DAYS") or "").strip() + if raw.isdigit() and int(raw) > 0: + return int(raw) + return _DEFAULT_TRUST_DAYS + + +def createTrustedDevice(userId: str, request: Request, response: Response, db) -> str: + """Create a TrustedDevice entry and set the cookie on the response. + + Returns the device token (cookie value), or an empty string if the record could not be persisted. 
+ """ + from modules.auth.jwtService import _cookiePolicy + + trustDays = _getTrustDurationDays() + deviceToken = secrets.token_urlsafe(_TOKEN_BYTES) + + now = getUtcTimestamp() + trustedUntil = now + (trustDays * 86400) + + device = TrustedDevice( + id=deviceToken, + userId=userId, + trustedUntil=trustedUntil, + userAgent=(request.headers.get("user-agent") or "")[:512], + ipAddress=_getClientIp(request), + createdAt=now, + ) + + try: + db.recordCreate(TrustedDevice, device.model_dump()) + except Exception as e: + logger.error(f"Failed to persist TrustedDevice for userId={userId}: {e}") + return "" + + useSecure, samesite, _ = _cookiePolicy() + response.set_cookie( + key=_COOKIE_NAME, + value=deviceToken, + httponly=True, + secure=useSecure, + samesite=samesite, + path="/", + max_age=trustDays * 86400, + ) + + logger.info(f"Trusted device created for userId={userId}, valid {trustDays}d") + return deviceToken + + +def isTrustedDevice(request: Request, userId: str, db) -> bool: + """Check if the current request comes from a trusted device for the given user.""" + deviceToken = request.cookies.get(_COOKIE_NAME) + if not deviceToken: + return False + + try: + records = db.getRecordset( + TrustedDevice, + recordFilter={"id": deviceToken, "userId": userId}, + ) + if not records: + return False + + device = records[0] + trustedUntil = device.get("trustedUntil", 0) + if isinstance(trustedUntil, (int, float)) and trustedUntil > getUtcTimestamp(): + return True + + return False + except Exception as e: + logger.warning(f"Error checking trusted device for userId={userId}: {e}") + return False + + +def revokeTrustedDevices(userId: str, db) -> int: + """Revoke all trusted devices for a user. 
Returns count of deleted entries.""" + try: + records = db.getRecordset(TrustedDevice, recordFilter={"userId": userId}) + count = 0 + for rec in records: + db.recordDelete(TrustedDevice, rec["id"]) + count += 1 + if count: + logger.info(f"Revoked {count} trusted device(s) for userId={userId}") + return count + except Exception as e: + logger.error(f"Failed to revoke trusted devices for userId={userId}: {e}") + return 0 + + +def clearTrustedDeviceCookie(response: Response) -> None: + """Clear the mfa_trusted cookie.""" + from modules.auth.jwtService import _cookiePolicy + + useSecure, samesite, samesiteHeader = _cookiePolicy() + secure_flag = "; Secure" if useSecure else "" + response.headers.append( + "Set-Cookie", + f"{_COOKIE_NAME}=deleted; Path=/; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly{secure_flag}; SameSite={samesiteHeader}" + ) + response.delete_cookie( + key=_COOKIE_NAME, + path="/", + secure=useSecure, + httponly=True, + samesite=samesite, + ) + + +def cleanupExpiredDevices(db) -> int: + """Remove TrustedDevice entries past their trustedUntil. 
Returns deleted count.""" + try: + records = db.getRecordset(TrustedDevice, recordFilter={}) + now = getUtcTimestamp() + count = 0 + for rec in records: + if rec.get("trustedUntil", 0) < now: + db.recordDelete(TrustedDevice, rec["id"]) + count += 1 + if count: + logger.info(f"Cleaned up {count} expired trusted device(s)") + return count + except Exception as e: + logger.error(f"Error cleaning up expired trusted devices: {e}") + return 0 + + +def _getClientIp(request: Request) -> Optional[str]: + """Extract client IP from request (respects X-Forwarded-For).""" + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + if request.client: + return request.client.host + return None diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index 11b406ad..115f25e4 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -871,6 +871,7 @@ class DatabaseConnector: ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text", ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE, ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer", + ("text", "BOOLEAN"): "BOOLEAN USING CASE WHEN \"{col}\" IN ('true', '1', 't', 'yes') THEN TRUE ELSE FALSE END", ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')', ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")', ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')', diff --git a/modules/connectors/connectorVoiceGoogle.py b/modules/connectors/connectorVoiceGoogle.py index 590fd26b..68e9ef84 100644 --- a/modules/connectors/connectorVoiceGoogle.py +++ b/modules/connectors/connectorVoiceGoogle.py @@ -19,6 +19,21 @@ from modules.shared.voiceCatalog import getDefaultVoice logger = logging.getLogger(__name__) +_STT_LANGUAGE_MAP = 
{ + "de-CH": "de-DE", + "de-AT": "de-DE", + "en-GB": "en-US", + "en-AU": "en-US", + "fr-CH": "fr-FR", + "it-CH": "it-IT", + "pt-BR": "pt-PT", +} + + +def _normalizeSttLanguage(language: str) -> str: + """Map regional language variants to codes supported by Google STT models.""" + return _STT_LANGUAGE_MAP.get(language, language) + def _buildPrimarySttRecognitionFields( *, @@ -116,6 +131,7 @@ class ConnectorGoogleSpeech: Returns: Dict containing transcribed text, confidence, and metadata """ + language = _normalizeSttLanguage(language) try: # Treat sampleRate=0 as unknown (invalid value from client) if sampleRate is not None and sampleRate <= 0: @@ -480,6 +496,7 @@ class ConnectorGoogleSpeech: Dicts with keys: isFinal, transcript, confidence, stabilityScore, audioDurationSec; optionally endOfSingleUtterance, reconnectRequired """ + language = _normalizeSttLanguage(language) STREAM_LIMIT_SEC = 290 streamStartTs = time.time() totalAudioBytes = 0 diff --git a/modules/datamodels/datamodelSecurity.py b/modules/datamodels/datamodelSecurity.py index 280fdc9e..e8cc148f 100644 --- a/modules/datamodels/datamodelSecurity.py +++ b/modules/datamodels/datamodelSecurity.py @@ -124,6 +124,43 @@ class Token(PowerOnModel): return data +@i18nModel("Vertrauenswuerdiges Geraet") +class TrustedDevice(PowerOnModel): + """A device trusted after successful MFA verification (skips MFA for configured duration).""" + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Random token stored as httpOnly cookie value", + json_schema_extra={"label": "ID"}, + ) + userId: str = Field( + ..., + description="User this trusted device belongs to", + json_schema_extra={"label": "Benutzer-ID", "fk_target": {"db": "poweron_app", "table": "UserInDB", "labelField": "username"}}, + ) + trustedUntil: float = Field( + ..., + description="UTC timestamp until which the device is trusted", + json_schema_extra={"label": "Vertrauenswuerdig bis", "frontend_type": "timestamp"}, + ) + userAgent: 
Optional[str] = Field( + default=None, + description="Browser user agent at time of trust grant", + json_schema_extra={"label": "User-Agent"}, + ) + ipAddress: Optional[str] = Field( + default=None, + description="IP address at time of trust grant", + json_schema_extra={"label": "IP-Adresse"}, + ) + createdAt: float = Field( + default_factory=getUtcTimestamp, + description="When the device was trusted", + json_schema_extra={"label": "Erstellt am", "frontend_type": "timestamp"}, + ) + + model_config = ConfigDict(use_enum_values=True) + + @i18nModel("Authentifizierungsereignis") class AuthEvent(PowerOnModel): """Authentication event for audit logging.""" diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py index df2603c4..c6d143dd 100644 --- a/modules/features/workspace/routeFeatureWorkspace.py +++ b/modules/features/workspace/routeFeatureWorkspace.py @@ -489,6 +489,158 @@ def _collectPriorFileIds(chatInterface, workflowId: str) -> List[str]: return result +# Default context budget for prior files (metadata injection per file costs tokens +# in EVERY agent round). Override per workspace instance via instanceConfig +# key "maxPriorContextFiles". Files outside the budget remain fully accessible +# to the agent via conversation history and readFile - nothing is lost. +DEFAULT_MAX_PRIOR_CONTEXT_FILES = 10 + + +async def _selectPriorFilesForContext( + priorFileIds: List[str], + prompt: str, + aiObjects, + mandateId: str, + featureInstanceId: str, + maxFiles: int, +) -> List[str]: + """Select which prior files get their metadata injected into the agent context. + + Selection only happens when more candidates exist than the budget allows - + otherwise all files pass through untouched. Selection order: + 1. Files whose indexed chunks are semantically relevant to the current prompt + 2. 
Remaining budget filled with the most recent files (so files without an + index entry are never unfairly dropped) + """ + if len(priorFileIds) <= maxFiles: + return priorFileIds + + rankedFileIds: List[str] = [] + try: + embeddingResponse = await aiObjects.callEmbedding([prompt]) + embeddings = (embeddingResponse.metadata or {}).get("embeddings", []) + if embeddings: + knowledgeIf = getKnowledgeInterface() + results = knowledgeIf.semanticSearch( + queryVector=embeddings[0], + featureInstanceId=featureInstanceId, + mandateId=mandateId, + limit=maxFiles * 3, + ) + priorSet = set(priorFileIds) + for chunk in results: + fid = chunk.get("fileId") if isinstance(chunk, dict) else getattr(chunk, "fileId", None) + if fid and fid in priorSet and fid not in rankedFileIds: + rankedFileIds.append(fid) + if len(rankedFileIds) >= maxFiles: + break + except Exception as e: + logger.warning(f"Relevance ranking for prior files failed: {e}") + + for fid in reversed(priorFileIds): + if len(rankedFileIds) >= maxFiles: + break + if fid not in rankedFileIds: + rankedFileIds.append(fid) + + logger.info( + f"Prior-file selection: {len(priorFileIds)} candidates -> {len(rankedFileIds)} in context " + f"(budget={maxFiles}, relevance-ranked, recency fill)" + ) + return rankedFileIds + + +async def _ensureFilesIndexed( + fileIds: List[str], + user, + mandateId: str, + featureInstanceId: str, +) -> int: + """Ensure all attached files have embeddings in the knowledge store. + + Checks FileContentIndex for each file. Files not yet indexed are extracted + and embedded inline so that subsequent RAG queries can find their content. + Indexing is idempotent (content-hash check in requestIngestion), so each + file is only ever processed once - re-attaching an indexed file is a no-op. + + Returns the number of files that were newly indexed. 
+ """ + if not fileIds: + return 0 + + knowledgeIf = getKnowledgeInterface(user) + unindexedIds = [] + for fid in fileIds: + existing = knowledgeIf.getFileContentIndex(fid) + status = (existing.get("status") if isinstance(existing, dict) else getattr(existing, "status", "")) if existing else "" + if status not in ("indexed", "embedding"): + unindexedIds.append(fid) + + if not unindexedIds: + return 0 + + logger.info(f"Ensure-embed: {len(unindexedIds)}/{len(fileIds)} files need indexing") + + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + indexCtx = ServiceCenterContext( + user=user, mandateId=mandateId, featureInstanceId=featureInstanceId, + ) + try: + knowledgeService = getService("knowledge", indexCtx) + except Exception as e: + logger.warning(f"Ensure-embed: knowledge service unavailable: {e}") + return 0 + + chatInterface = interfaceDbChat.getInterface(user) + indexed = 0 + for fid in unindexedIds: + try: + fileInfo = chatInterface.getFileInfo(fid) if chatInterface else None + if not fileInfo: + continue + fileName = fileInfo.get("fileName", "") + mimeType = fileInfo.get("mimeType", "") + rawBytes = chatInterface.getFileData(fid) + if not rawBytes: + continue + + extractionService = getService("extraction", indexCtx) + extracted = extractionService.extractContentFromBytes( + rawBytes, fileName, mimeType, documentId=fid, + ) + contentObjects = [ + { + "contentType": getattr(p, "contentType", "text"), + "data": getattr(p, "data", "") or "", + "contentObjectId": getattr(p, "contentObjectId", "") or str(uuid.uuid4()), + "contextRef": getattr(p, "contextRef", {}) or {}, + } + for p in (extracted.parts or []) + if getattr(p, "data", None) + ] + if not contentObjects: + continue + + await knowledgeService.indexFile( + fileId=fid, + fileName=fileName, + mimeType=mimeType, + userId=user.id if user else "", + featureInstanceId=featureInstanceId, + mandateId=mandateId, + contentObjects=contentObjects, + 
structure=getattr(extracted, "structure", None), + ) + indexed += 1 + except Exception as e: + logger.debug(f"Ensure-embed: skipping file {fid}: {e}") + + if indexed: + logger.info(f"Ensure-embed: indexed {indexed} file(s) before agent start") + return indexed + + async def _deriveWorkflowName(prompt: str, aiService) -> str: """Use AI to generate a concise workflow title from the user prompt.""" from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum @@ -740,20 +892,29 @@ async def _runWorkspaceAgent( priorFileIds = _collectPriorFileIds(chatInterface, workflowId) currentFileIdSet = set(fileIds or []) - mergedFileIds = list(fileIds or []) - for pf in priorFileIds: - if pf not in currentFileIdSet: - mergedFileIds.append(pf) - if len(mergedFileIds) > len(fileIds or []): + candidatePriorIds = [pf for pf in priorFileIds if pf not in currentFileIdSet] + + # Embed-first rule: newly attached files are indexed into the knowledge + # store BEFORE the agent starts, so RAG retrieval works from round 1. 
+ await _ensureFilesIndexed(fileIds or [], user, mandateId, instanceId) + + _cfg = instanceConfig or {} + + if candidatePriorIds: + maxPriorFiles = int(_cfg.get("maxPriorContextFiles", DEFAULT_MAX_PRIOR_CONTEXT_FILES)) + candidatePriorIds = await _selectPriorFilesForContext( + candidatePriorIds, prompt, aiObjects, mandateId, instanceId, maxPriorFiles, + ) + + mergedFileIds = list(fileIds or []) + candidatePriorIds + if candidatePriorIds: logger.info( - f"Merged {len(mergedFileIds) - len(fileIds or [])} prior file(s) into agent context " + f"Merged {len(candidatePriorIds)} prior file(s) into agent context " f"(total: {len(mergedFileIds)}) for workflow {workflowId}" ) accumulatedText = "" messagePersisted = False - - _cfg = instanceConfig or {} _toolSet = _cfg.get("toolSet", "core") _agentCfg = _cfg.get("agentConfig") from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py index c36e10b6..a1800648 100644 --- a/modules/interfaces/interfaceAiObjects.py +++ b/modules/interfaces/interfaceAiObjects.py @@ -4,7 +4,9 @@ import logging import asyncio import uuid import base64 +import hashlib import json +from collections import OrderedDict from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator from dataclasses import dataclass, field import time @@ -542,7 +544,27 @@ class AiObjects: else: options.operationType = OperationTypeEnum.EMBEDDING - combinedText = " ".join(texts[:3])[:500] + # Serve known vectors from cache; only unknown texts go to the API. 
+ resolvedVectors: Dict[int, List[float]] = {} + pendingTexts: List[str] = [] + pendingPositions: List[int] = [] + for i, t in enumerate(texts): + cachedVector = _embeddingCacheGet(t) + if cachedVector is not None: + resolvedVectors[i] = cachedVector + else: + pendingTexts.append(t) + pendingPositions.append(i) + + if not pendingTexts: + logger.debug(f"Embedding cache hit for all {len(texts)} text(s)") + return AiCallResponse( + content="", modelName="embedding-cache", priceCHF=0.0, + processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=0, + metadata={"embeddings": [resolvedVectors[i] for i in range(len(texts))]}, + ) + + combinedText = " ".join(pendingTexts[:3])[:500] availableModels = modelRegistry.getAvailableModels() allowedProviders = getattr(options, 'allowedProviders', None) if options else None @@ -575,13 +597,13 @@ class AiObjects: for attempt, model in enumerate(failoverModelList): try: logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})") - inputBytes = sum(len(t.encode("utf-8")) for t in texts) + inputBytes = sum(len(t.encode("utf-8")) for t in pendingTexts) startTime = time.time() - batches = _buildEmbeddingBatches(texts, model.contextLength) + batches = _buildEmbeddingBatches(pendingTexts, model.contextLength) logger.info( - f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), " - f"model contextLength={model.contextLength}" + f"Embedding: {len(pendingTexts)} texts ({len(resolvedVectors)} cached) -> " + f"{len(batches)} batch(es), model contextLength={model.contextLength}" ) allEmbeddings: List[List[float]] = [] @@ -606,11 +628,17 @@ class AiObjects: if totalPriceCHF == 0.0: totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0) + for j, position in enumerate(pendingPositions): + if j < len(allEmbeddings): + resolvedVectors[position] = allEmbeddings[j] + _embeddingCachePut(pendingTexts[j], allEmbeddings[j]) + mergedEmbeddings = [resolvedVectors.get(i, []) for i in 
range(len(texts))] + response = AiCallResponse( content="", modelName=model.name, provider=model.connectorType, priceCHF=totalPriceCHF, processingTime=processingTime, bytesSent=inputBytes, bytesReceived=0, errorCount=0, - metadata={"embeddings": allEmbeddings} + metadata={"embeddings": mergedEmbeddings} ) if self.billingCallback: @@ -681,6 +709,28 @@ class AiObjects: _CHARS_PER_TOKEN = 4 _SAFETY_MARGIN = 0.90 +# In-process cache for embedding vectors. Identical texts (e.g. the same user +# prompt embedded once for prior-file selection and once for RAG context +# building) hit the cache instead of paying for a second API call. +_EMBEDDING_CACHE_MAX_ENTRIES = 256 +_embeddingCache: OrderedDict = OrderedDict() + + +def _embeddingCacheGet(text: str) -> Optional[List[float]]: + key = hashlib.sha256(text.encode("utf-8")).hexdigest() + vector = _embeddingCache.get(key) + if vector is not None: + _embeddingCache.move_to_end(key) + return vector + + +def _embeddingCachePut(text: str, vector: List[float]) -> None: + key = hashlib.sha256(text.encode("utf-8")).hexdigest() + _embeddingCache[key] = vector + _embeddingCache.move_to_end(key) + while len(_embeddingCache) > _EMBEDDING_CACHE_MAX_ENTRIES: + _embeddingCache.popitem(last=False) + def _estimateTokens(text: str) -> int: """Rough token estimate: 1 token ~ 4 characters.""" @@ -691,9 +741,7 @@ def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[st """Split a list of texts into batches whose total estimated token count stays within the model's contextLength (with safety margin). - Each individual text is assumed to already be within limits (enforced by - the chunking layer). If a single text exceeds the budget, it is placed - in its own batch as a last resort. + Texts that individually exceed the per-input limit are truncated to fit. 
""" if not texts: return [] @@ -701,11 +749,21 @@ def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[st return [texts] maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN) + maxCharsPerInput = maxTokensPerBatch * _CHARS_PER_TOKEN batches: List[List[str]] = [] currentBatch: List[str] = [] currentTokens = 0 for text in texts: + if len(text) > maxCharsPerInput: + # API hard limit per input. File content never hits this (the + # chunking layer splits at ~400 tokens); only oversized search + # queries can, where truncation is semantically acceptable. + logger.warning( + f"Embedding input truncated from {len(text)} to {maxCharsPerInput} chars " + f"(model input limit {contextLength} tokens)" + ) + text = text[:maxCharsPerInput] textTokens = _estimateTokens(text) if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch: batches.append(currentBatch) diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 52cd5a59..a3dda3ba 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -3112,8 +3112,12 @@ class AppObjects: # Token methods - def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None: - """Save an access token for the current user (must NOT have connectionId)""" + def saveAccessToken(self, token: Token, replace_existing: bool = False) -> None: + """Save an access token for the current user (must NOT have connectionId). + + Multi-session: replace_existing=False (default) keeps existing sessions alive. + Only set replace_existing=True for explicit single-session scenarios. + """ try: # Validate that this is NOT a connection token if token.connectionId: diff --git a/modules/routes/routeAdminSessions.py b/modules/routes/routeAdminSessions.py new file mode 100644 index 00000000..d962a9ea --- /dev/null +++ b/modules/routes/routeAdminSessions.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026 PowerOn AG +# All rights reserved. 
+""" +Admin endpoints for session and trusted device management. + +Allows platform admins and sysAdmins to view and revoke active sessions and trusted +devices of any user. NOTE(review): no mandate-admin access or per-mandate scoping is enforced here. +""" + +from fastapi import APIRouter, HTTPException, status, Depends, Request, Query +from typing import Dict, Any, List +import logging + +from modules.auth import limiter, getCurrentUser +from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelSecurity import Token, TokenPurpose, TokenStatus, TrustedDevice +from modules.interfaces.interfaceDbApp import getRootInterface +from modules.shared.timeUtils import getUtcTimestamp +from modules.shared.i18nRegistry import apiRouteContext + +routeApiMsg = apiRouteContext("routeAdminSessions") +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/api/admin/sessions", + tags=["Admin Sessions"], + responses={404: {"description": "Not found"}}, +) + + +def _requireAdmin(currentUser: User) -> None: + """Ensure the caller is a platform admin or sysAdmin.""" + if not (getattr(currentUser, "isPlatformAdmin", False) or getattr(currentUser, "isSysAdmin", False)): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=routeApiMsg("Only platform admins can manage sessions"), + ) + + +@router.get("") +@limiter.limit("30/minute") +def listSessions( + request: Request, + userId: str = Query(..., description="User ID whose sessions to list"), + currentUser: User = Depends(getCurrentUser), +) -> List[Dict[str, Any]]: + """List active auth sessions for a user.""" + _requireAdmin(currentUser) + rootInterface = getRootInterface() + + tokens = rootInterface.db.getRecordset( + Token, + recordFilter={ + "userId": userId, + "tokenPurpose": TokenPurpose.AUTH_SESSION.value, + "status": TokenStatus.ACTIVE.value, + }, + ) + + now = getUtcTimestamp() + result = [] + for t in tokens: + expiresAt = t.get("expiresAt", 0) + if expiresAt < now: + continue + result.append({ + "sessionId": 
t.get("sessionId"), + "tokenId": t.get("id"), + "authority": t.get("authority"), + "createdAt": t.get("sysCreatedAt"), + "expiresAt": expiresAt, + }) + + return result + + +@router.delete("/{sessionId}") +@limiter.limit("30/minute") +def revokeSession( + request: Request, + sessionId: str, + currentUser: User = Depends(getCurrentUser), +) -> Dict[str, Any]: + """Revoke a single session by sessionId.""" + _requireAdmin(currentUser) + rootInterface = getRootInterface() + + tokens = rootInterface.db.getRecordset( + Token, + recordFilter={"sessionId": sessionId, "tokenPurpose": TokenPurpose.AUTH_SESSION.value}, + ) + count = 0 + for t in tokens: + rootInterface.db.recordDelete(Token, t["id"]) + count += 1 + + if count == 0: + raise HTTPException(status_code=404, detail=routeApiMsg("Session not found")) + + logger.info(f"Admin {currentUser.username} revoked session {sessionId} ({count} token(s))") + return {"revoked": count, "sessionId": sessionId} + + +@router.delete("") +@limiter.limit("10/minute") +def revokeAllSessions( + request: Request, + userId: str = Query(..., description="User ID whose sessions to revoke"), + currentUser: User = Depends(getCurrentUser), +) -> Dict[str, Any]: + """Revoke ALL active sessions for a user (force logout everywhere).""" + _requireAdmin(currentUser) + rootInterface = getRootInterface() + + tokens = rootInterface.db.getRecordset( + Token, + recordFilter={ + "userId": userId, + "tokenPurpose": TokenPurpose.AUTH_SESSION.value, + }, + ) + count = 0 + for t in tokens: + rootInterface.db.recordDelete(Token, t["id"]) + count += 1 + + logger.info(f"Admin {currentUser.username} revoked all sessions for userId={userId} ({count} token(s))") + return {"revoked": count, "userId": userId} + + +# --- Trusted Devices --- + +trustedDeviceRouter = APIRouter( + prefix="/api/admin/trusted-devices", + tags=["Admin Sessions"], + responses={404: {"description": "Not found"}}, +) + + +@trustedDeviceRouter.get("") +@limiter.limit("30/minute") +def 
listTrustedDevices( + request: Request, + userId: str = Query(..., description="User ID whose trusted devices to list"), + currentUser: User = Depends(getCurrentUser), +) -> List[Dict[str, Any]]: + """List trusted devices for a user.""" + _requireAdmin(currentUser) + rootInterface = getRootInterface() + + devices = rootInterface.db.getRecordset( + TrustedDevice, recordFilter={"userId": userId} + ) + + now = getUtcTimestamp() + result = [] + for d in devices: + result.append({ + "id": d.get("id", "")[:8] + "...", + "trustedUntil": d.get("trustedUntil"), + "isExpired": d.get("trustedUntil", 0) < now, + "userAgent": d.get("userAgent"), + "ipAddress": d.get("ipAddress"), + "createdAt": d.get("createdAt"), + }) + + return result + + +@trustedDeviceRouter.delete("") +@limiter.limit("10/minute") +def revokeAllTrustedDevices( + request: Request, + userId: str = Query(..., description="User ID whose trusted devices to revoke"), + currentUser: User = Depends(getCurrentUser), +) -> Dict[str, Any]: + """Revoke ALL trusted devices for a user (force MFA on next login).""" + _requireAdmin(currentUser) + rootInterface = getRootInterface() + + from modules.auth.trustedDeviceService import revokeTrustedDevices + count = revokeTrustedDevices(userId, rootInterface.db) + + logger.info(f"Admin {currentUser.username} revoked all trusted devices for userId={userId} ({count})") + return {"revoked": count, "userId": userId} diff --git a/modules/routes/routeMfa.py b/modules/routes/routeMfa.py index 0d3e4d59..5b13d592 100644 --- a/modules/routes/routeMfa.py +++ b/modules/routes/routeMfa.py @@ -232,6 +232,13 @@ def mfaVerify( logger.info("MFA verify successful for user %s", username) + # Mark device as trusted so MFA is skipped on next login from this device + try: + from modules.auth.trustedDeviceService import createTrustedDevice + createTrustedDevice(userId, request, response, rootInterface.db) + except Exception as e: + logger.warning(f"Failed to create trusted device after MFA verify: 
{e}") + try: from modules.dbHelpers.auditLogger import audit_logger audit_logger.logUserAccess( diff --git a/modules/routes/routeSecurityLocal.py b/modules/routes/routeSecurityLocal.py index ee2f6390..3107a2b8 100644 --- a/modules/routes/routeSecurityLocal.py +++ b/modules/routes/routeSecurityLocal.py @@ -256,6 +256,7 @@ def login( # --- MFA gate -------------------------------------------------------- from modules.auth.mfaService import isMfaRequired as _isMfaRequired + from modules.auth.trustedDeviceService import isTrustedDevice as _isTrustedDevice from modules.routes.routeMfa import createMfaPendingToken userRecord = rootInterface._getUserForAuthentication(user.username) @@ -273,7 +274,14 @@ def login( mfaRequired = _isMfaRequired(user, userMandates=userMandates, mandates=mandateObjs) hasMfaSetup = bool(userRecord and userRecord.get("mfaSecret") and getattr(user, "mfaEnabled", False)) + # Trusted device: skip MFA if the device was previously verified + _deviceTrusted = False if mfaRequired or hasMfaSetup: + _deviceTrusted = _isTrustedDevice(request, str(user.id), rootInterface.db) + if _deviceTrusted: + logger.info(f"MFA skipped for user {user.username} (trusted device)") + + if (mfaRequired or hasMfaSetup) and not _deviceTrusted: _sid = str(uuid.uuid4()) pendingToken = createMfaPendingToken( userId=str(user.id), @@ -659,31 +667,46 @@ def refresh_token( logger.error(f"Failed to get user from database: {str(e)}") raise HTTPException(status_code=500, detail=routeApiMsg("Failed to validate user")) + # Preserve sessionId from the refresh token so the session stays grouped + sessionId = payload.get("sid") or str(uuid.uuid4()) + # Create new token data - # MULTI-TENANT: Token does NOT contain mandateId anymore + newJti = str(uuid.uuid4()) token_data = { "sub": current_user.username, "userId": str(current_user.id), - "authenticationAuthority": current_user.authenticationAuthority - # NO mandateId in token + "authenticationAuthority": 
current_user.authenticationAuthority, + "jti": newJti, + "sid": sessionId, } - + # Create new access token + set cookie - access_token, _expires = createAccessToken(token_data) + access_token, accessExpires = createAccessToken(token_data) setAccessTokenCookie(response, access_token) - - # Get expiration time + + # Persist the new token in DB so _getUserBase() accepts it + authority = current_user.authenticationAuthority + if isinstance(authority, str): + authority = AuthAuthority(authority) + dbToken = Token( + id=newJti, + userId=str(current_user.id), + authority=authority, + tokenAccess=access_token, + tokenPurpose=TokenPurpose.AUTH_SESSION, + expiresAt=accessExpires.timestamp(), + sessionId=sessionId, + ) try: - payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM]) - expires_at = datetime.fromtimestamp(payload.get("exp")) + userInterface = getInterface(current_user) + userInterface.saveAccessToken(dbToken) except Exception as e: - logger.error(f"Failed to decode new access token: {str(e)}") - raise HTTPException(status_code=500, detail=routeApiMsg("Failed to create new token")) - + logger.warning(f"Failed to persist refreshed token in DB: {e}") + return { "type": "token_refresh_success", "message": "Token refreshed successfully", - "expires_at": expires_at.isoformat() + "expires_at": accessExpires.isoformat() } except HTTPException as e: diff --git a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py index d2c0830b..6ad29488 100644 --- a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py +++ b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py @@ -638,7 +638,8 @@ class KnowledgeService: Returns: Formatted context string for injection into the agent's system prompt. 
""" - queryVector = await self._embedSingle(currentPrompt) + queryText = _extractUserQuery(currentPrompt) + queryVector = await self._embedSingle(queryText) if queryText else [] logger.debug( "buildAgentContext.start userId=%s featureInstanceId=%s mandateId=%s isSysAdmin=%s prompt=%r", userId, featureInstanceId, mandateId, isSysAdmin, (currentPrompt or "")[:120], @@ -960,6 +961,29 @@ class KnowledgeService: # Internal helpers # ============================================================================= +# Markers added by the prompt enrichment layers (_enrichPromptWithFiles in +# mainServiceAgent, data source sections in routeFeatureWorkspace). Used to +# isolate the user's own words for the semantic search query. +_USER_REQUEST_MARKER = "\n\nUser request: " +_PROMPT_SECTION_MARKERS = ("\n\n[Active Data Sources]\n", "\n\n[Attached Feature Data Sources]\n") + + +def _extractUserQuery(prompt: str) -> str: + """Isolate the user's actual question from an enriched agent prompt. + + Enriched prompts wrap the user request in file metadata headers and + data-source sections. Only the user's own words form a meaningful semantic + search query - embedding the metadata would dilute the vector and can + exceed the embedding model's input limit. + """ + if not prompt: + return "" + text = prompt.rsplit(_USER_REQUEST_MARKER, 1)[-1] + for marker in _PROMPT_SECTION_MARKERS: + text = text.split(marker, 1)[0] + return text.strip() + + def _estimateTokens(text: str) -> int: """Estimate token count using character-based heuristic (1 token ~ 4 chars).""" return max(1, len(text) // CHARS_PER_TOKEN)