ux cleanup
This commit is contained in:
parent
535bd43174
commit
a1d9c68604
11 changed files with 721 additions and 33 deletions
171
modules/auth/trustedDeviceService.py
Normal file
171
modules/auth/trustedDeviceService.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright (c) 2026 PowerOn AG
|
||||
# All rights reserved.
|
||||
"""
|
||||
Trusted Device Service.
|
||||
|
||||
After successful MFA verification a device can be marked as trusted for a
|
||||
configurable duration (default 60 days). On subsequent logins from the same
|
||||
device the MFA step is skipped.
|
||||
|
||||
Cookie: ``mfa_trusted`` (httpOnly, Secure, SameSite policy from jwtService).
|
||||
DB: ``TrustedDevice`` table in poweron_app.
|
||||
|
||||
Regulatory basis:
|
||||
- NIST SP 800-63B Section 5.2.8: Verifier MAY re-authenticate only after a
|
||||
configurable period when a device is bound to the subscriber.
|
||||
- Microsoft, Google, AWS implement identical patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timeUtils import getUtcNow, getUtcTimestamp
|
||||
from modules.datamodels.datamodelSecurity import TrustedDevice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_COOKIE_NAME = "mfa_trusted"
|
||||
_DEFAULT_TRUST_DAYS = 60
|
||||
_TOKEN_BYTES = 32
|
||||
|
||||
|
||||
def _getTrustDurationDays() -> int:
|
||||
raw = (APP_CONFIG.get("MFA_TRUST_DURATION_DAYS") or "").strip()
|
||||
if raw.isdigit() and int(raw) > 0:
|
||||
return int(raw)
|
||||
return _DEFAULT_TRUST_DAYS
|
||||
|
||||
|
||||
def createTrustedDevice(userId: str, request: Request, response: Response, db) -> str:
|
||||
"""Create a TrustedDevice entry and set the cookie on the response.
|
||||
|
||||
Returns the device token (cookie value).
|
||||
"""
|
||||
from modules.auth.jwtService import _cookiePolicy
|
||||
|
||||
trustDays = _getTrustDurationDays()
|
||||
deviceToken = secrets.token_urlsafe(_TOKEN_BYTES)
|
||||
|
||||
now = getUtcTimestamp()
|
||||
trustedUntil = now + (trustDays * 86400)
|
||||
|
||||
device = TrustedDevice(
|
||||
id=deviceToken,
|
||||
userId=userId,
|
||||
trustedUntil=trustedUntil,
|
||||
userAgent=(request.headers.get("user-agent") or "")[:512],
|
||||
ipAddress=_getClientIp(request),
|
||||
createdAt=now,
|
||||
)
|
||||
|
||||
try:
|
||||
db.recordCreate(TrustedDevice, device.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist TrustedDevice for userId={userId}: {e}")
|
||||
return ""
|
||||
|
||||
useSecure, samesite, _ = _cookiePolicy()
|
||||
response.set_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
value=deviceToken,
|
||||
httponly=True,
|
||||
secure=useSecure,
|
||||
samesite=samesite,
|
||||
path="/",
|
||||
max_age=trustDays * 86400,
|
||||
)
|
||||
|
||||
logger.info(f"Trusted device created for userId={userId}, valid {trustDays}d")
|
||||
return deviceToken
|
||||
|
||||
|
||||
def isTrustedDevice(request: Request, userId: str, db) -> bool:
|
||||
"""Check if the current request comes from a trusted device for the given user."""
|
||||
deviceToken = request.cookies.get(_COOKIE_NAME)
|
||||
if not deviceToken:
|
||||
return False
|
||||
|
||||
try:
|
||||
records = db.getRecordset(
|
||||
TrustedDevice,
|
||||
recordFilter={"id": deviceToken, "userId": userId},
|
||||
)
|
||||
if not records:
|
||||
return False
|
||||
|
||||
device = records[0]
|
||||
trustedUntil = device.get("trustedUntil", 0)
|
||||
if isinstance(trustedUntil, (int, float)) and trustedUntil > getUtcTimestamp():
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking trusted device for userId={userId}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def revokeTrustedDevices(userId: str, db) -> int:
|
||||
"""Revoke all trusted devices for a user. Returns count of deleted entries."""
|
||||
try:
|
||||
records = db.getRecordset(TrustedDevice, recordFilter={"userId": userId})
|
||||
count = 0
|
||||
for rec in records:
|
||||
db.recordDelete(TrustedDevice, rec["id"])
|
||||
count += 1
|
||||
if count:
|
||||
logger.info(f"Revoked {count} trusted device(s) for userId={userId}")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke trusted devices for userId={userId}: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def clearTrustedDeviceCookie(response: Response) -> None:
|
||||
"""Clear the mfa_trusted cookie."""
|
||||
from modules.auth.jwtService import _cookiePolicy
|
||||
|
||||
useSecure, samesite, samesiteHeader = _cookiePolicy()
|
||||
secure_flag = "; Secure" if useSecure else ""
|
||||
response.headers.append(
|
||||
"Set-Cookie",
|
||||
f"{_COOKIE_NAME}=deleted; Path=/; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly{secure_flag}; SameSite={samesiteHeader}"
|
||||
)
|
||||
response.delete_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
path="/",
|
||||
secure=useSecure,
|
||||
httponly=True,
|
||||
samesite=samesite,
|
||||
)
|
||||
|
||||
|
||||
def cleanupExpiredDevices(db) -> int:
|
||||
"""Remove TrustedDevice entries past their trustedUntil. Returns deleted count."""
|
||||
try:
|
||||
records = db.getRecordset(TrustedDevice, recordFilter={})
|
||||
now = getUtcTimestamp()
|
||||
count = 0
|
||||
for rec in records:
|
||||
if rec.get("trustedUntil", 0) < now:
|
||||
db.recordDelete(TrustedDevice, rec["id"])
|
||||
count += 1
|
||||
if count:
|
||||
logger.info(f"Cleaned up {count} expired trusted device(s)")
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up expired trusted devices: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def _getClientIp(request: Request) -> Optional[str]:
|
||||
"""Extract client IP from request (respects X-Forwarded-For)."""
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return None
|
||||
|
|
@ -871,6 +871,7 @@ class DatabaseConnector:
|
|||
("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
|
||||
("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
|
||||
("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
|
||||
("text", "BOOLEAN"): "BOOLEAN USING CASE WHEN \"{col}\" IN ('true', '1', 't', 'yes') THEN TRUE ELSE FALSE END",
|
||||
("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
|
||||
("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
|
||||
("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
|
||||
|
|
|
|||
|
|
@ -19,6 +19,21 @@ from modules.shared.voiceCatalog import getDefaultVoice
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STT_LANGUAGE_MAP = {
|
||||
"de-CH": "de-DE",
|
||||
"de-AT": "de-DE",
|
||||
"en-GB": "en-US",
|
||||
"en-AU": "en-US",
|
||||
"fr-CH": "fr-FR",
|
||||
"it-CH": "it-IT",
|
||||
"pt-BR": "pt-PT",
|
||||
}
|
||||
|
||||
|
||||
def _normalizeSttLanguage(language: str) -> str:
|
||||
"""Map regional language variants to codes supported by Google STT models."""
|
||||
return _STT_LANGUAGE_MAP.get(language, language)
|
||||
|
||||
|
||||
def _buildPrimarySttRecognitionFields(
|
||||
*,
|
||||
|
|
@ -116,6 +131,7 @@ class ConnectorGoogleSpeech:
|
|||
Returns:
|
||||
Dict containing transcribed text, confidence, and metadata
|
||||
"""
|
||||
language = _normalizeSttLanguage(language)
|
||||
try:
|
||||
# Treat sampleRate=0 as unknown (invalid value from client)
|
||||
if sampleRate is not None and sampleRate <= 0:
|
||||
|
|
@ -480,6 +496,7 @@ class ConnectorGoogleSpeech:
|
|||
Dicts with keys: isFinal, transcript, confidence, stabilityScore, audioDurationSec;
|
||||
optionally endOfSingleUtterance, reconnectRequired
|
||||
"""
|
||||
language = _normalizeSttLanguage(language)
|
||||
STREAM_LIMIT_SEC = 290
|
||||
streamStartTs = time.time()
|
||||
totalAudioBytes = 0
|
||||
|
|
|
|||
|
|
@ -124,6 +124,43 @@ class Token(PowerOnModel):
|
|||
return data
|
||||
|
||||
|
||||
@i18nModel("Vertrauenswuerdiges Geraet")
|
||||
class TrustedDevice(PowerOnModel):
|
||||
"""A device trusted after successful MFA verification (skips MFA for configured duration)."""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Random token stored as httpOnly cookie value",
|
||||
json_schema_extra={"label": "ID"},
|
||||
)
|
||||
userId: str = Field(
|
||||
...,
|
||||
description="User this trusted device belongs to",
|
||||
json_schema_extra={"label": "Benutzer-ID", "fk_target": {"db": "poweron_app", "table": "UserInDB", "labelField": "username"}},
|
||||
)
|
||||
trustedUntil: float = Field(
|
||||
...,
|
||||
description="UTC timestamp until which the device is trusted",
|
||||
json_schema_extra={"label": "Vertrauenswuerdig bis", "frontend_type": "timestamp"},
|
||||
)
|
||||
userAgent: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Browser user agent at time of trust grant",
|
||||
json_schema_extra={"label": "User-Agent"},
|
||||
)
|
||||
ipAddress: Optional[str] = Field(
|
||||
default=None,
|
||||
description="IP address at time of trust grant",
|
||||
json_schema_extra={"label": "IP-Adresse"},
|
||||
)
|
||||
createdAt: float = Field(
|
||||
default_factory=getUtcTimestamp,
|
||||
description="When the device was trusted",
|
||||
json_schema_extra={"label": "Erstellt am", "frontend_type": "timestamp"},
|
||||
)
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
@i18nModel("Authentifizierungsereignis")
|
||||
class AuthEvent(PowerOnModel):
|
||||
"""Authentication event for audit logging."""
|
||||
|
|
|
|||
|
|
@ -489,6 +489,158 @@ def _collectPriorFileIds(chatInterface, workflowId: str) -> List[str]:
|
|||
return result
|
||||
|
||||
|
||||
# Default context budget for prior files (metadata injection per file costs tokens
|
||||
# in EVERY agent round). Override per workspace instance via instanceConfig
|
||||
# key "maxPriorContextFiles". Files outside the budget remain fully accessible
|
||||
# to the agent via conversation history and readFile - nothing is lost.
|
||||
DEFAULT_MAX_PRIOR_CONTEXT_FILES = 10
|
||||
|
||||
|
||||
async def _selectPriorFilesForContext(
|
||||
priorFileIds: List[str],
|
||||
prompt: str,
|
||||
aiObjects,
|
||||
mandateId: str,
|
||||
featureInstanceId: str,
|
||||
maxFiles: int,
|
||||
) -> List[str]:
|
||||
"""Select which prior files get their metadata injected into the agent context.
|
||||
|
||||
Selection only happens when more candidates exist than the budget allows -
|
||||
otherwise all files pass through untouched. Selection order:
|
||||
1. Files whose indexed chunks are semantically relevant to the current prompt
|
||||
2. Remaining budget filled with the most recent files (so files without an
|
||||
index entry are never unfairly dropped)
|
||||
"""
|
||||
if len(priorFileIds) <= maxFiles:
|
||||
return priorFileIds
|
||||
|
||||
rankedFileIds: List[str] = []
|
||||
try:
|
||||
embeddingResponse = await aiObjects.callEmbedding([prompt])
|
||||
embeddings = (embeddingResponse.metadata or {}).get("embeddings", [])
|
||||
if embeddings:
|
||||
knowledgeIf = getKnowledgeInterface()
|
||||
results = knowledgeIf.semanticSearch(
|
||||
queryVector=embeddings[0],
|
||||
featureInstanceId=featureInstanceId,
|
||||
mandateId=mandateId,
|
||||
limit=maxFiles * 3,
|
||||
)
|
||||
priorSet = set(priorFileIds)
|
||||
for chunk in results:
|
||||
fid = chunk.get("fileId") if isinstance(chunk, dict) else getattr(chunk, "fileId", None)
|
||||
if fid and fid in priorSet and fid not in rankedFileIds:
|
||||
rankedFileIds.append(fid)
|
||||
if len(rankedFileIds) >= maxFiles:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Relevance ranking for prior files failed: {e}")
|
||||
|
||||
for fid in reversed(priorFileIds):
|
||||
if len(rankedFileIds) >= maxFiles:
|
||||
break
|
||||
if fid not in rankedFileIds:
|
||||
rankedFileIds.append(fid)
|
||||
|
||||
logger.info(
|
||||
f"Prior-file selection: {len(priorFileIds)} candidates -> {len(rankedFileIds)} in context "
|
||||
f"(budget={maxFiles}, relevance-ranked, recency fill)"
|
||||
)
|
||||
return rankedFileIds
|
||||
|
||||
|
||||
async def _ensureFilesIndexed(
|
||||
fileIds: List[str],
|
||||
user,
|
||||
mandateId: str,
|
||||
featureInstanceId: str,
|
||||
) -> int:
|
||||
"""Ensure all attached files have embeddings in the knowledge store.
|
||||
|
||||
Checks FileContentIndex for each file. Files not yet indexed are extracted
|
||||
and embedded inline so that subsequent RAG queries can find their content.
|
||||
Indexing is idempotent (content-hash check in requestIngestion), so each
|
||||
file is only ever processed once - re-attaching an indexed file is a no-op.
|
||||
|
||||
Returns the number of files that were newly indexed.
|
||||
"""
|
||||
if not fileIds:
|
||||
return 0
|
||||
|
||||
knowledgeIf = getKnowledgeInterface(user)
|
||||
unindexedIds = []
|
||||
for fid in fileIds:
|
||||
existing = knowledgeIf.getFileContentIndex(fid)
|
||||
status = (existing.get("status") if isinstance(existing, dict) else getattr(existing, "status", "")) if existing else ""
|
||||
if status not in ("indexed", "embedding"):
|
||||
unindexedIds.append(fid)
|
||||
|
||||
if not unindexedIds:
|
||||
return 0
|
||||
|
||||
logger.info(f"Ensure-embed: {len(unindexedIds)}/{len(fileIds)} files need indexing")
|
||||
|
||||
from modules.serviceCenter import getService
|
||||
from modules.serviceCenter.context import ServiceCenterContext
|
||||
indexCtx = ServiceCenterContext(
|
||||
user=user, mandateId=mandateId, featureInstanceId=featureInstanceId,
|
||||
)
|
||||
try:
|
||||
knowledgeService = getService("knowledge", indexCtx)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ensure-embed: knowledge service unavailable: {e}")
|
||||
return 0
|
||||
|
||||
chatInterface = interfaceDbChat.getInterface(user)
|
||||
indexed = 0
|
||||
for fid in unindexedIds:
|
||||
try:
|
||||
fileInfo = chatInterface.getFileInfo(fid) if chatInterface else None
|
||||
if not fileInfo:
|
||||
continue
|
||||
fileName = fileInfo.get("fileName", "")
|
||||
mimeType = fileInfo.get("mimeType", "")
|
||||
rawBytes = chatInterface.getFileData(fid)
|
||||
if not rawBytes:
|
||||
continue
|
||||
|
||||
extractionService = getService("extraction", indexCtx)
|
||||
extracted = extractionService.extractContentFromBytes(
|
||||
rawBytes, fileName, mimeType, documentId=fid,
|
||||
)
|
||||
contentObjects = [
|
||||
{
|
||||
"contentType": getattr(p, "contentType", "text"),
|
||||
"data": getattr(p, "data", "") or "",
|
||||
"contentObjectId": getattr(p, "contentObjectId", "") or str(uuid.uuid4()),
|
||||
"contextRef": getattr(p, "contextRef", {}) or {},
|
||||
}
|
||||
for p in (extracted.parts or [])
|
||||
if getattr(p, "data", None)
|
||||
]
|
||||
if not contentObjects:
|
||||
continue
|
||||
|
||||
await knowledgeService.indexFile(
|
||||
fileId=fid,
|
||||
fileName=fileName,
|
||||
mimeType=mimeType,
|
||||
userId=user.id if user else "",
|
||||
featureInstanceId=featureInstanceId,
|
||||
mandateId=mandateId,
|
||||
contentObjects=contentObjects,
|
||||
structure=getattr(extracted, "structure", None),
|
||||
)
|
||||
indexed += 1
|
||||
except Exception as e:
|
||||
logger.debug(f"Ensure-embed: skipping file {fid}: {e}")
|
||||
|
||||
if indexed:
|
||||
logger.info(f"Ensure-embed: indexed {indexed} file(s) before agent start")
|
||||
return indexed
|
||||
|
||||
|
||||
async def _deriveWorkflowName(prompt: str, aiService) -> str:
|
||||
"""Use AI to generate a concise workflow title from the user prompt."""
|
||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
|
||||
|
|
@ -740,20 +892,29 @@ async def _runWorkspaceAgent(
|
|||
|
||||
priorFileIds = _collectPriorFileIds(chatInterface, workflowId)
|
||||
currentFileIdSet = set(fileIds or [])
|
||||
mergedFileIds = list(fileIds or [])
|
||||
for pf in priorFileIds:
|
||||
if pf not in currentFileIdSet:
|
||||
mergedFileIds.append(pf)
|
||||
if len(mergedFileIds) > len(fileIds or []):
|
||||
candidatePriorIds = [pf for pf in priorFileIds if pf not in currentFileIdSet]
|
||||
|
||||
# Embed-first rule: newly attached files are indexed into the knowledge
|
||||
# store BEFORE the agent starts, so RAG retrieval works from round 1.
|
||||
await _ensureFilesIndexed(fileIds or [], user, mandateId, instanceId)
|
||||
|
||||
_cfg = instanceConfig or {}
|
||||
|
||||
if candidatePriorIds:
|
||||
maxPriorFiles = int(_cfg.get("maxPriorContextFiles", DEFAULT_MAX_PRIOR_CONTEXT_FILES))
|
||||
candidatePriorIds = await _selectPriorFilesForContext(
|
||||
candidatePriorIds, prompt, aiObjects, mandateId, instanceId, maxPriorFiles,
|
||||
)
|
||||
|
||||
mergedFileIds = list(fileIds or []) + candidatePriorIds
|
||||
if candidatePriorIds:
|
||||
logger.info(
|
||||
f"Merged {len(mergedFileIds) - len(fileIds or [])} prior file(s) into agent context "
|
||||
f"Merged {len(candidatePriorIds)} prior file(s) into agent context "
|
||||
f"(total: {len(mergedFileIds)}) for workflow {workflowId}"
|
||||
)
|
||||
|
||||
accumulatedText = ""
|
||||
messagePersisted = False
|
||||
|
||||
_cfg = instanceConfig or {}
|
||||
_toolSet = _cfg.get("toolSet", "core")
|
||||
_agentCfg = _cfg.get("agentConfig")
|
||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import logging
|
|||
import asyncio
|
||||
import uuid
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
import time
|
||||
|
|
@ -542,7 +544,27 @@ class AiObjects:
|
|||
else:
|
||||
options.operationType = OperationTypeEnum.EMBEDDING
|
||||
|
||||
combinedText = " ".join(texts[:3])[:500]
|
||||
# Serve known vectors from cache; only unknown texts go to the API.
|
||||
resolvedVectors: Dict[int, List[float]] = {}
|
||||
pendingTexts: List[str] = []
|
||||
pendingPositions: List[int] = []
|
||||
for i, t in enumerate(texts):
|
||||
cachedVector = _embeddingCacheGet(t)
|
||||
if cachedVector is not None:
|
||||
resolvedVectors[i] = cachedVector
|
||||
else:
|
||||
pendingTexts.append(t)
|
||||
pendingPositions.append(i)
|
||||
|
||||
if not pendingTexts:
|
||||
logger.debug(f"Embedding cache hit for all {len(texts)} text(s)")
|
||||
return AiCallResponse(
|
||||
content="", modelName="embedding-cache", priceCHF=0.0,
|
||||
processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=0,
|
||||
metadata={"embeddings": [resolvedVectors[i] for i in range(len(texts))]},
|
||||
)
|
||||
|
||||
combinedText = " ".join(pendingTexts[:3])[:500]
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
|
||||
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
|
||||
|
|
@ -575,13 +597,13 @@ class AiObjects:
|
|||
for attempt, model in enumerate(failoverModelList):
|
||||
try:
|
||||
logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
|
||||
inputBytes = sum(len(t.encode("utf-8")) for t in texts)
|
||||
inputBytes = sum(len(t.encode("utf-8")) for t in pendingTexts)
|
||||
startTime = time.time()
|
||||
|
||||
batches = _buildEmbeddingBatches(texts, model.contextLength)
|
||||
batches = _buildEmbeddingBatches(pendingTexts, model.contextLength)
|
||||
logger.info(
|
||||
f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), "
|
||||
f"model contextLength={model.contextLength}"
|
||||
f"Embedding: {len(pendingTexts)} texts ({len(resolvedVectors)} cached) -> "
|
||||
f"{len(batches)} batch(es), model contextLength={model.contextLength}"
|
||||
)
|
||||
|
||||
allEmbeddings: List[List[float]] = []
|
||||
|
|
@ -606,11 +628,17 @@ class AiObjects:
|
|||
if totalPriceCHF == 0.0:
|
||||
totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
|
||||
|
||||
for j, position in enumerate(pendingPositions):
|
||||
if j < len(allEmbeddings):
|
||||
resolvedVectors[position] = allEmbeddings[j]
|
||||
_embeddingCachePut(pendingTexts[j], allEmbeddings[j])
|
||||
mergedEmbeddings = [resolvedVectors.get(i, []) for i in range(len(texts))]
|
||||
|
||||
response = AiCallResponse(
|
||||
content="", modelName=model.name, provider=model.connectorType,
|
||||
priceCHF=totalPriceCHF, processingTime=processingTime,
|
||||
bytesSent=inputBytes, bytesReceived=0, errorCount=0,
|
||||
metadata={"embeddings": allEmbeddings}
|
||||
metadata={"embeddings": mergedEmbeddings}
|
||||
)
|
||||
|
||||
if self.billingCallback:
|
||||
|
|
@ -681,6 +709,28 @@ class AiObjects:
|
|||
_CHARS_PER_TOKEN = 4
|
||||
_SAFETY_MARGIN = 0.90
|
||||
|
||||
# In-process cache for embedding vectors. Identical texts (e.g. the same user
|
||||
# prompt embedded once for prior-file selection and once for RAG context
|
||||
# building) hit the cache instead of paying for a second API call.
|
||||
_EMBEDDING_CACHE_MAX_ENTRIES = 256
|
||||
_embeddingCache: OrderedDict = OrderedDict()
|
||||
|
||||
|
||||
def _embeddingCacheGet(text: str) -> Optional[List[float]]:
|
||||
key = hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
vector = _embeddingCache.get(key)
|
||||
if vector is not None:
|
||||
_embeddingCache.move_to_end(key)
|
||||
return vector
|
||||
|
||||
|
||||
def _embeddingCachePut(text: str, vector: List[float]) -> None:
|
||||
key = hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
_embeddingCache[key] = vector
|
||||
_embeddingCache.move_to_end(key)
|
||||
while len(_embeddingCache) > _EMBEDDING_CACHE_MAX_ENTRIES:
|
||||
_embeddingCache.popitem(last=False)
|
||||
|
||||
|
||||
def _estimateTokens(text: str) -> int:
|
||||
"""Rough token estimate: 1 token ~ 4 characters."""
|
||||
|
|
@ -691,9 +741,7 @@ def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[st
|
|||
"""Split a list of texts into batches whose total estimated token count
|
||||
stays within the model's contextLength (with safety margin).
|
||||
|
||||
Each individual text is assumed to already be within limits (enforced by
|
||||
the chunking layer). If a single text exceeds the budget, it is placed
|
||||
in its own batch as a last resort.
|
||||
Texts that individually exceed the per-input limit are truncated to fit.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
|
@ -701,11 +749,21 @@ def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[st
|
|||
return [texts]
|
||||
|
||||
maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN)
|
||||
maxCharsPerInput = maxTokensPerBatch * _CHARS_PER_TOKEN
|
||||
batches: List[List[str]] = []
|
||||
currentBatch: List[str] = []
|
||||
currentTokens = 0
|
||||
|
||||
for text in texts:
|
||||
if len(text) > maxCharsPerInput:
|
||||
# API hard limit per input. File content never hits this (the
|
||||
# chunking layer splits at ~400 tokens); only oversized search
|
||||
# queries can, where truncation is semantically acceptable.
|
||||
logger.warning(
|
||||
f"Embedding input truncated from {len(text)} to {maxCharsPerInput} chars "
|
||||
f"(model input limit {contextLength} tokens)"
|
||||
)
|
||||
text = text[:maxCharsPerInput]
|
||||
textTokens = _estimateTokens(text)
|
||||
if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch:
|
||||
batches.append(currentBatch)
|
||||
|
|
|
|||
|
|
@ -3112,8 +3112,12 @@ class AppObjects:
|
|||
|
||||
# Token methods
|
||||
|
||||
def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None:
|
||||
"""Save an access token for the current user (must NOT have connectionId)"""
|
||||
def saveAccessToken(self, token: Token, replace_existing: bool = False) -> None:
|
||||
"""Save an access token for the current user (must NOT have connectionId).
|
||||
|
||||
Multi-session: replace_existing=False (default) keeps existing sessions alive.
|
||||
Only set replace_existing=True for explicit single-session scenarios.
|
||||
"""
|
||||
try:
|
||||
# Validate that this is NOT a connection token
|
||||
if token.connectionId:
|
||||
|
|
|
|||
185
modules/routes/routeAdminSessions.py
Normal file
185
modules/routes/routeAdminSessions.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
# Copyright (c) 2026 PowerOn AG
|
||||
# All rights reserved.
|
||||
"""
|
||||
Admin endpoints for session and trusted device management.
|
||||
|
||||
Allows mandate-admins and platform-admins to view and revoke active sessions
|
||||
and trusted devices for users under their jurisdiction.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Request, Query
|
||||
from typing import Dict, Any, List
|
||||
import logging
|
||||
|
||||
from modules.auth import limiter, getCurrentUser
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelSecurity import Token, TokenPurpose, TokenStatus, TrustedDevice
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
from modules.shared.i18nRegistry import apiRouteContext
|
||||
|
||||
routeApiMsg = apiRouteContext("routeAdminSessions")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/admin/sessions",
|
||||
tags=["Admin Sessions"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
def _requireAdmin(currentUser: User) -> None:
|
||||
"""Ensure the caller is a platform admin or sysAdmin."""
|
||||
if not (getattr(currentUser, "isPlatformAdmin", False) or getattr(currentUser, "isSysAdmin", False)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=routeApiMsg("Only platform admins can manage sessions"),
|
||||
)
|
||||
|
||||
|
||||
@router.get("")
|
||||
@limiter.limit("30/minute")
|
||||
def listSessions(
|
||||
request: Request,
|
||||
userId: str = Query(..., description="User ID whose sessions to list"),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List active auth sessions for a user."""
|
||||
_requireAdmin(currentUser)
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
tokens = rootInterface.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={
|
||||
"userId": userId,
|
||||
"tokenPurpose": TokenPurpose.AUTH_SESSION.value,
|
||||
"status": TokenStatus.ACTIVE.value,
|
||||
},
|
||||
)
|
||||
|
||||
now = getUtcTimestamp()
|
||||
result = []
|
||||
for t in tokens:
|
||||
expiresAt = t.get("expiresAt", 0)
|
||||
if expiresAt < now:
|
||||
continue
|
||||
result.append({
|
||||
"sessionId": t.get("sessionId"),
|
||||
"tokenId": t.get("id"),
|
||||
"authority": t.get("authority"),
|
||||
"createdAt": t.get("sysCreatedAt"),
|
||||
"expiresAt": expiresAt,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{sessionId}")
|
||||
@limiter.limit("30/minute")
|
||||
def revokeSession(
|
||||
request: Request,
|
||||
sessionId: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
) -> Dict[str, Any]:
|
||||
"""Revoke a single session by sessionId."""
|
||||
_requireAdmin(currentUser)
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
tokens = rootInterface.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={"sessionId": sessionId, "tokenPurpose": TokenPurpose.AUTH_SESSION.value},
|
||||
)
|
||||
count = 0
|
||||
for t in tokens:
|
||||
rootInterface.db.recordDelete(Token, t["id"])
|
||||
count += 1
|
||||
|
||||
if count == 0:
|
||||
raise HTTPException(status_code=404, detail=routeApiMsg("Session not found"))
|
||||
|
||||
logger.info(f"Admin {currentUser.username} revoked session {sessionId} ({count} token(s))")
|
||||
return {"revoked": count, "sessionId": sessionId}
|
||||
|
||||
|
||||
@router.delete("")
|
||||
@limiter.limit("10/minute")
|
||||
def revokeAllSessions(
|
||||
request: Request,
|
||||
userId: str = Query(..., description="User ID whose sessions to revoke"),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
) -> Dict[str, Any]:
|
||||
"""Revoke ALL active sessions for a user (force logout everywhere)."""
|
||||
_requireAdmin(currentUser)
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
tokens = rootInterface.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={
|
||||
"userId": userId,
|
||||
"tokenPurpose": TokenPurpose.AUTH_SESSION.value,
|
||||
},
|
||||
)
|
||||
count = 0
|
||||
for t in tokens:
|
||||
rootInterface.db.recordDelete(Token, t["id"])
|
||||
count += 1
|
||||
|
||||
logger.info(f"Admin {currentUser.username} revoked all sessions for userId={userId} ({count} token(s))")
|
||||
return {"revoked": count, "userId": userId}
|
||||
|
||||
|
||||
# --- Trusted Devices ---
|
||||
|
||||
trustedDeviceRouter = APIRouter(
|
||||
prefix="/api/admin/trusted-devices",
|
||||
tags=["Admin Sessions"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@trustedDeviceRouter.get("")
|
||||
@limiter.limit("30/minute")
|
||||
def listTrustedDevices(
|
||||
request: Request,
|
||||
userId: str = Query(..., description="User ID whose trusted devices to list"),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List trusted devices for a user."""
|
||||
_requireAdmin(currentUser)
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
devices = rootInterface.db.getRecordset(
|
||||
TrustedDevice, recordFilter={"userId": userId}
|
||||
)
|
||||
|
||||
now = getUtcTimestamp()
|
||||
result = []
|
||||
for d in devices:
|
||||
result.append({
|
||||
"id": d.get("id", "")[:8] + "...",
|
||||
"trustedUntil": d.get("trustedUntil"),
|
||||
"isExpired": d.get("trustedUntil", 0) < now,
|
||||
"userAgent": d.get("userAgent"),
|
||||
"ipAddress": d.get("ipAddress"),
|
||||
"createdAt": d.get("createdAt"),
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@trustedDeviceRouter.delete("")
|
||||
@limiter.limit("10/minute")
|
||||
def revokeAllTrustedDevices(
|
||||
request: Request,
|
||||
userId: str = Query(..., description="User ID whose trusted devices to revoke"),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
) -> Dict[str, Any]:
|
||||
"""Revoke ALL trusted devices for a user (force MFA on next login)."""
|
||||
_requireAdmin(currentUser)
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
from modules.auth.trustedDeviceService import revokeTrustedDevices
|
||||
count = revokeTrustedDevices(userId, rootInterface.db)
|
||||
|
||||
logger.info(f"Admin {currentUser.username} revoked all trusted devices for userId={userId} ({count})")
|
||||
return {"revoked": count, "userId": userId}
|
||||
|
|
@ -232,6 +232,13 @@ def mfaVerify(
|
|||
|
||||
logger.info("MFA verify successful for user %s", username)
|
||||
|
||||
# Mark device as trusted so MFA is skipped on next login from this device
|
||||
try:
|
||||
from modules.auth.trustedDeviceService import createTrustedDevice
|
||||
createTrustedDevice(userId, request, response, rootInterface.db)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create trusted device after MFA verify: {e}")
|
||||
|
||||
try:
|
||||
from modules.dbHelpers.auditLogger import audit_logger
|
||||
audit_logger.logUserAccess(
|
||||
|
|
|
|||
|
|
@ -256,6 +256,7 @@ def login(
|
|||
|
||||
# --- MFA gate --------------------------------------------------------
|
||||
from modules.auth.mfaService import isMfaRequired as _isMfaRequired
|
||||
from modules.auth.trustedDeviceService import isTrustedDevice as _isTrustedDevice
|
||||
from modules.routes.routeMfa import createMfaPendingToken
|
||||
|
||||
userRecord = rootInterface._getUserForAuthentication(user.username)
|
||||
|
|
@ -273,7 +274,14 @@ def login(
|
|||
mfaRequired = _isMfaRequired(user, userMandates=userMandates, mandates=mandateObjs)
|
||||
hasMfaSetup = bool(userRecord and userRecord.get("mfaSecret") and getattr(user, "mfaEnabled", False))
|
||||
|
||||
# Trusted device: skip MFA if the device was previously verified
|
||||
_deviceTrusted = False
|
||||
if mfaRequired or hasMfaSetup:
|
||||
_deviceTrusted = _isTrustedDevice(request, str(user.id), rootInterface.db)
|
||||
if _deviceTrusted:
|
||||
logger.info(f"MFA skipped for user {user.username} (trusted device)")
|
||||
|
||||
if (mfaRequired or hasMfaSetup) and not _deviceTrusted:
|
||||
_sid = str(uuid.uuid4())
|
||||
pendingToken = createMfaPendingToken(
|
||||
userId=str(user.id),
|
||||
|
|
@ -659,31 +667,46 @@ def refresh_token(
|
|||
logger.error(f"Failed to get user from database: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=routeApiMsg("Failed to validate user"))
|
||||
|
||||
# Preserve sessionId from the refresh token so the session stays grouped
|
||||
sessionId = payload.get("sid") or str(uuid.uuid4())
|
||||
|
||||
# Create new token data
|
||||
# MULTI-TENANT: Token does NOT contain mandateId anymore
|
||||
newJti = str(uuid.uuid4())
|
||||
token_data = {
|
||||
"sub": current_user.username,
|
||||
"userId": str(current_user.id),
|
||||
"authenticationAuthority": current_user.authenticationAuthority
|
||||
# NO mandateId in token
|
||||
"authenticationAuthority": current_user.authenticationAuthority,
|
||||
"jti": newJti,
|
||||
"sid": sessionId,
|
||||
}
|
||||
|
||||
# Create new access token + set cookie
|
||||
access_token, _expires = createAccessToken(token_data)
|
||||
access_token, accessExpires = createAccessToken(token_data)
|
||||
setAccessTokenCookie(response, access_token)
|
||||
|
||||
# Get expiration time
|
||||
# Persist the new token in DB so _getUserBase() accepts it
|
||||
authority = current_user.authenticationAuthority
|
||||
if isinstance(authority, str):
|
||||
authority = AuthAuthority(authority)
|
||||
dbToken = Token(
|
||||
id=newJti,
|
||||
userId=str(current_user.id),
|
||||
authority=authority,
|
||||
tokenAccess=access_token,
|
||||
tokenPurpose=TokenPurpose.AUTH_SESSION,
|
||||
expiresAt=accessExpires.timestamp(),
|
||||
sessionId=sessionId,
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||
userInterface = getInterface(current_user)
|
||||
userInterface.saveAccessToken(dbToken)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode new access token: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=routeApiMsg("Failed to create new token"))
|
||||
logger.warning(f"Failed to persist refreshed token in DB: {e}")
|
||||
|
||||
return {
|
||||
"type": "token_refresh_success",
|
||||
"message": "Token refreshed successfully",
|
||||
"expires_at": expires_at.isoformat()
|
||||
"expires_at": accessExpires.isoformat()
|
||||
}
|
||||
|
||||
except HTTPException as e:
|
||||
|
|
|
|||
|
|
@ -638,7 +638,8 @@ class KnowledgeService:
|
|||
Returns:
|
||||
Formatted context string for injection into the agent's system prompt.
|
||||
"""
|
||||
queryVector = await self._embedSingle(currentPrompt)
|
||||
queryText = _extractUserQuery(currentPrompt)
|
||||
queryVector = await self._embedSingle(queryText) if queryText else []
|
||||
logger.debug(
|
||||
"buildAgentContext.start userId=%s featureInstanceId=%s mandateId=%s isSysAdmin=%s prompt=%r",
|
||||
userId, featureInstanceId, mandateId, isSysAdmin, (currentPrompt or "")[:120],
|
||||
|
|
@ -960,6 +961,29 @@ class KnowledgeService:
|
|||
# Internal helpers
|
||||
# =============================================================================
|
||||
|
||||
# Markers added by the prompt enrichment layers (_enrichPromptWithFiles in
|
||||
# mainServiceAgent, data source sections in routeFeatureWorkspace). Used to
|
||||
# isolate the user's own words for the semantic search query.
|
||||
_USER_REQUEST_MARKER = "\n\nUser request: "
|
||||
_PROMPT_SECTION_MARKERS = ("\n\n[Active Data Sources]\n", "\n\n[Attached Feature Data Sources]\n")
|
||||
|
||||
|
||||
def _extractUserQuery(prompt: str) -> str:
|
||||
"""Isolate the user's actual question from an enriched agent prompt.
|
||||
|
||||
Enriched prompts wrap the user request in file metadata headers and
|
||||
data-source sections. Only the user's own words form a meaningful semantic
|
||||
search query - embedding the metadata would dilute the vector and can
|
||||
exceed the embedding model's input limit.
|
||||
"""
|
||||
if not prompt:
|
||||
return ""
|
||||
text = prompt.rsplit(_USER_REQUEST_MARKER, 1)[-1]
|
||||
for marker in _PROMPT_SECTION_MARKERS:
|
||||
text = text.split(marker, 1)[0]
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _estimateTokens(text: str) -> int:
|
||||
"""Estimate token count using character-based heuristic (1 token ~ 4 chars)."""
|
||||
return max(1, len(text) // CHARS_PER_TOKEN)
|
||||
|
|
|
|||
Loading…
Reference in a new issue