ux cleanup
This commit is contained in:
parent
535bd43174
commit
a1d9c68604
11 changed files with 721 additions and 33 deletions
171
modules/auth/trustedDeviceService.py
Normal file
171
modules/auth/trustedDeviceService.py
Normal file
|
|
@ -0,0 +1,171 @@
|
||||||
|
# Copyright (c) 2026 PowerOn AG
|
||||||
|
# All rights reserved.
|
||||||
|
"""
|
||||||
|
Trusted Device Service.
|
||||||
|
|
||||||
|
After successful MFA verification a device can be marked as trusted for a
|
||||||
|
configurable duration (default 60 days). On subsequent logins from the same
|
||||||
|
device the MFA step is skipped.
|
||||||
|
|
||||||
|
Cookie: ``mfa_trusted`` (httpOnly, Secure, SameSite policy from jwtService).
|
||||||
|
DB: ``TrustedDevice`` table in poweron_app.
|
||||||
|
|
||||||
|
Regulatory basis:
|
||||||
|
- NIST SP 800-63B Section 5.2.8: Verifier MAY re-authenticate only after a
|
||||||
|
configurable period when a device is bound to the subscriber.
|
||||||
|
- Microsoft, Google, AWS implement identical patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
|
||||||
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
from modules.shared.timeUtils import getUtcNow, getUtcTimestamp
|
||||||
|
from modules.datamodels.datamodelSecurity import TrustedDevice
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_COOKIE_NAME = "mfa_trusted"
|
||||||
|
_DEFAULT_TRUST_DAYS = 60
|
||||||
|
_TOKEN_BYTES = 32
|
||||||
|
|
||||||
|
|
||||||
|
def _getTrustDurationDays() -> int:
|
||||||
|
raw = (APP_CONFIG.get("MFA_TRUST_DURATION_DAYS") or "").strip()
|
||||||
|
if raw.isdigit() and int(raw) > 0:
|
||||||
|
return int(raw)
|
||||||
|
return _DEFAULT_TRUST_DAYS
|
||||||
|
|
||||||
|
|
||||||
|
def createTrustedDevice(userId: str, request: Request, response: Response, db) -> str:
|
||||||
|
"""Create a TrustedDevice entry and set the cookie on the response.
|
||||||
|
|
||||||
|
Returns the device token (cookie value).
|
||||||
|
"""
|
||||||
|
from modules.auth.jwtService import _cookiePolicy
|
||||||
|
|
||||||
|
trustDays = _getTrustDurationDays()
|
||||||
|
deviceToken = secrets.token_urlsafe(_TOKEN_BYTES)
|
||||||
|
|
||||||
|
now = getUtcTimestamp()
|
||||||
|
trustedUntil = now + (trustDays * 86400)
|
||||||
|
|
||||||
|
device = TrustedDevice(
|
||||||
|
id=deviceToken,
|
||||||
|
userId=userId,
|
||||||
|
trustedUntil=trustedUntil,
|
||||||
|
userAgent=(request.headers.get("user-agent") or "")[:512],
|
||||||
|
ipAddress=_getClientIp(request),
|
||||||
|
createdAt=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db.recordCreate(TrustedDevice, device.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to persist TrustedDevice for userId={userId}: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
useSecure, samesite, _ = _cookiePolicy()
|
||||||
|
response.set_cookie(
|
||||||
|
key=_COOKIE_NAME,
|
||||||
|
value=deviceToken,
|
||||||
|
httponly=True,
|
||||||
|
secure=useSecure,
|
||||||
|
samesite=samesite,
|
||||||
|
path="/",
|
||||||
|
max_age=trustDays * 86400,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Trusted device created for userId={userId}, valid {trustDays}d")
|
||||||
|
return deviceToken
|
||||||
|
|
||||||
|
|
||||||
|
def isTrustedDevice(request: Request, userId: str, db) -> bool:
|
||||||
|
"""Check if the current request comes from a trusted device for the given user."""
|
||||||
|
deviceToken = request.cookies.get(_COOKIE_NAME)
|
||||||
|
if not deviceToken:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
records = db.getRecordset(
|
||||||
|
TrustedDevice,
|
||||||
|
recordFilter={"id": deviceToken, "userId": userId},
|
||||||
|
)
|
||||||
|
if not records:
|
||||||
|
return False
|
||||||
|
|
||||||
|
device = records[0]
|
||||||
|
trustedUntil = device.get("trustedUntil", 0)
|
||||||
|
if isinstance(trustedUntil, (int, float)) and trustedUntil > getUtcTimestamp():
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error checking trusted device for userId={userId}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def revokeTrustedDevices(userId: str, db) -> int:
|
||||||
|
"""Revoke all trusted devices for a user. Returns count of deleted entries."""
|
||||||
|
try:
|
||||||
|
records = db.getRecordset(TrustedDevice, recordFilter={"userId": userId})
|
||||||
|
count = 0
|
||||||
|
for rec in records:
|
||||||
|
db.recordDelete(TrustedDevice, rec["id"])
|
||||||
|
count += 1
|
||||||
|
if count:
|
||||||
|
logger.info(f"Revoked {count} trusted device(s) for userId={userId}")
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to revoke trusted devices for userId={userId}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def clearTrustedDeviceCookie(response: Response) -> None:
|
||||||
|
"""Clear the mfa_trusted cookie."""
|
||||||
|
from modules.auth.jwtService import _cookiePolicy
|
||||||
|
|
||||||
|
useSecure, samesite, samesiteHeader = _cookiePolicy()
|
||||||
|
secure_flag = "; Secure" if useSecure else ""
|
||||||
|
response.headers.append(
|
||||||
|
"Set-Cookie",
|
||||||
|
f"{_COOKIE_NAME}=deleted; Path=/; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly{secure_flag}; SameSite={samesiteHeader}"
|
||||||
|
)
|
||||||
|
response.delete_cookie(
|
||||||
|
key=_COOKIE_NAME,
|
||||||
|
path="/",
|
||||||
|
secure=useSecure,
|
||||||
|
httponly=True,
|
||||||
|
samesite=samesite,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanupExpiredDevices(db) -> int:
|
||||||
|
"""Remove TrustedDevice entries past their trustedUntil. Returns deleted count."""
|
||||||
|
try:
|
||||||
|
records = db.getRecordset(TrustedDevice, recordFilter={})
|
||||||
|
now = getUtcTimestamp()
|
||||||
|
count = 0
|
||||||
|
for rec in records:
|
||||||
|
if rec.get("trustedUntil", 0) < now:
|
||||||
|
db.recordDelete(TrustedDevice, rec["id"])
|
||||||
|
count += 1
|
||||||
|
if count:
|
||||||
|
logger.info(f"Cleaned up {count} expired trusted device(s)")
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cleaning up expired trusted devices: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _getClientIp(request: Request) -> Optional[str]:
|
||||||
|
"""Extract client IP from request (respects X-Forwarded-For)."""
|
||||||
|
forwarded = request.headers.get("x-forwarded-for")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
return None
|
||||||
|
|
@ -871,6 +871,7 @@ class DatabaseConnector:
|
||||||
("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
|
("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
|
||||||
("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
|
("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
|
||||||
("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
|
("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
|
||||||
|
("text", "BOOLEAN"): "BOOLEAN USING CASE WHEN \"{col}\" IN ('true', '1', 't', 'yes') THEN TRUE ELSE FALSE END",
|
||||||
("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
|
("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
|
||||||
("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
|
("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
|
||||||
("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
|
("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,21 @@ from modules.shared.voiceCatalog import getDefaultVoice
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_STT_LANGUAGE_MAP = {
|
||||||
|
"de-CH": "de-DE",
|
||||||
|
"de-AT": "de-DE",
|
||||||
|
"en-GB": "en-US",
|
||||||
|
"en-AU": "en-US",
|
||||||
|
"fr-CH": "fr-FR",
|
||||||
|
"it-CH": "it-IT",
|
||||||
|
"pt-BR": "pt-PT",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalizeSttLanguage(language: str) -> str:
|
||||||
|
"""Map regional language variants to codes supported by Google STT models."""
|
||||||
|
return _STT_LANGUAGE_MAP.get(language, language)
|
||||||
|
|
||||||
|
|
||||||
def _buildPrimarySttRecognitionFields(
|
def _buildPrimarySttRecognitionFields(
|
||||||
*,
|
*,
|
||||||
|
|
@ -116,6 +131,7 @@ class ConnectorGoogleSpeech:
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing transcribed text, confidence, and metadata
|
Dict containing transcribed text, confidence, and metadata
|
||||||
"""
|
"""
|
||||||
|
language = _normalizeSttLanguage(language)
|
||||||
try:
|
try:
|
||||||
# Treat sampleRate=0 as unknown (invalid value from client)
|
# Treat sampleRate=0 as unknown (invalid value from client)
|
||||||
if sampleRate is not None and sampleRate <= 0:
|
if sampleRate is not None and sampleRate <= 0:
|
||||||
|
|
@ -480,6 +496,7 @@ class ConnectorGoogleSpeech:
|
||||||
Dicts with keys: isFinal, transcript, confidence, stabilityScore, audioDurationSec;
|
Dicts with keys: isFinal, transcript, confidence, stabilityScore, audioDurationSec;
|
||||||
optionally endOfSingleUtterance, reconnectRequired
|
optionally endOfSingleUtterance, reconnectRequired
|
||||||
"""
|
"""
|
||||||
|
language = _normalizeSttLanguage(language)
|
||||||
STREAM_LIMIT_SEC = 290
|
STREAM_LIMIT_SEC = 290
|
||||||
streamStartTs = time.time()
|
streamStartTs = time.time()
|
||||||
totalAudioBytes = 0
|
totalAudioBytes = 0
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,43 @@ class Token(PowerOnModel):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@i18nModel("Vertrauenswuerdiges Geraet")
|
||||||
|
class TrustedDevice(PowerOnModel):
|
||||||
|
"""A device trusted after successful MFA verification (skips MFA for configured duration)."""
|
||||||
|
id: str = Field(
|
||||||
|
default_factory=lambda: str(uuid.uuid4()),
|
||||||
|
description="Random token stored as httpOnly cookie value",
|
||||||
|
json_schema_extra={"label": "ID"},
|
||||||
|
)
|
||||||
|
userId: str = Field(
|
||||||
|
...,
|
||||||
|
description="User this trusted device belongs to",
|
||||||
|
json_schema_extra={"label": "Benutzer-ID", "fk_target": {"db": "poweron_app", "table": "UserInDB", "labelField": "username"}},
|
||||||
|
)
|
||||||
|
trustedUntil: float = Field(
|
||||||
|
...,
|
||||||
|
description="UTC timestamp until which the device is trusted",
|
||||||
|
json_schema_extra={"label": "Vertrauenswuerdig bis", "frontend_type": "timestamp"},
|
||||||
|
)
|
||||||
|
userAgent: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Browser user agent at time of trust grant",
|
||||||
|
json_schema_extra={"label": "User-Agent"},
|
||||||
|
)
|
||||||
|
ipAddress: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="IP address at time of trust grant",
|
||||||
|
json_schema_extra={"label": "IP-Adresse"},
|
||||||
|
)
|
||||||
|
createdAt: float = Field(
|
||||||
|
default_factory=getUtcTimestamp,
|
||||||
|
description="When the device was trusted",
|
||||||
|
json_schema_extra={"label": "Erstellt am", "frontend_type": "timestamp"},
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
|
|
||||||
@i18nModel("Authentifizierungsereignis")
|
@i18nModel("Authentifizierungsereignis")
|
||||||
class AuthEvent(PowerOnModel):
|
class AuthEvent(PowerOnModel):
|
||||||
"""Authentication event for audit logging."""
|
"""Authentication event for audit logging."""
|
||||||
|
|
|
||||||
|
|
@ -489,6 +489,158 @@ def _collectPriorFileIds(chatInterface, workflowId: str) -> List[str]:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Default context budget for prior files (metadata injection per file costs tokens
|
||||||
|
# in EVERY agent round). Override per workspace instance via instanceConfig
|
||||||
|
# key "maxPriorContextFiles". Files outside the budget remain fully accessible
|
||||||
|
# to the agent via conversation history and readFile - nothing is lost.
|
||||||
|
DEFAULT_MAX_PRIOR_CONTEXT_FILES = 10
|
||||||
|
|
||||||
|
|
||||||
|
async def _selectPriorFilesForContext(
|
||||||
|
priorFileIds: List[str],
|
||||||
|
prompt: str,
|
||||||
|
aiObjects,
|
||||||
|
mandateId: str,
|
||||||
|
featureInstanceId: str,
|
||||||
|
maxFiles: int,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Select which prior files get their metadata injected into the agent context.
|
||||||
|
|
||||||
|
Selection only happens when more candidates exist than the budget allows -
|
||||||
|
otherwise all files pass through untouched. Selection order:
|
||||||
|
1. Files whose indexed chunks are semantically relevant to the current prompt
|
||||||
|
2. Remaining budget filled with the most recent files (so files without an
|
||||||
|
index entry are never unfairly dropped)
|
||||||
|
"""
|
||||||
|
if len(priorFileIds) <= maxFiles:
|
||||||
|
return priorFileIds
|
||||||
|
|
||||||
|
rankedFileIds: List[str] = []
|
||||||
|
try:
|
||||||
|
embeddingResponse = await aiObjects.callEmbedding([prompt])
|
||||||
|
embeddings = (embeddingResponse.metadata or {}).get("embeddings", [])
|
||||||
|
if embeddings:
|
||||||
|
knowledgeIf = getKnowledgeInterface()
|
||||||
|
results = knowledgeIf.semanticSearch(
|
||||||
|
queryVector=embeddings[0],
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
limit=maxFiles * 3,
|
||||||
|
)
|
||||||
|
priorSet = set(priorFileIds)
|
||||||
|
for chunk in results:
|
||||||
|
fid = chunk.get("fileId") if isinstance(chunk, dict) else getattr(chunk, "fileId", None)
|
||||||
|
if fid and fid in priorSet and fid not in rankedFileIds:
|
||||||
|
rankedFileIds.append(fid)
|
||||||
|
if len(rankedFileIds) >= maxFiles:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Relevance ranking for prior files failed: {e}")
|
||||||
|
|
||||||
|
for fid in reversed(priorFileIds):
|
||||||
|
if len(rankedFileIds) >= maxFiles:
|
||||||
|
break
|
||||||
|
if fid not in rankedFileIds:
|
||||||
|
rankedFileIds.append(fid)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Prior-file selection: {len(priorFileIds)} candidates -> {len(rankedFileIds)} in context "
|
||||||
|
f"(budget={maxFiles}, relevance-ranked, recency fill)"
|
||||||
|
)
|
||||||
|
return rankedFileIds
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensureFilesIndexed(
|
||||||
|
fileIds: List[str],
|
||||||
|
user,
|
||||||
|
mandateId: str,
|
||||||
|
featureInstanceId: str,
|
||||||
|
) -> int:
|
||||||
|
"""Ensure all attached files have embeddings in the knowledge store.
|
||||||
|
|
||||||
|
Checks FileContentIndex for each file. Files not yet indexed are extracted
|
||||||
|
and embedded inline so that subsequent RAG queries can find their content.
|
||||||
|
Indexing is idempotent (content-hash check in requestIngestion), so each
|
||||||
|
file is only ever processed once - re-attaching an indexed file is a no-op.
|
||||||
|
|
||||||
|
Returns the number of files that were newly indexed.
|
||||||
|
"""
|
||||||
|
if not fileIds:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
knowledgeIf = getKnowledgeInterface(user)
|
||||||
|
unindexedIds = []
|
||||||
|
for fid in fileIds:
|
||||||
|
existing = knowledgeIf.getFileContentIndex(fid)
|
||||||
|
status = (existing.get("status") if isinstance(existing, dict) else getattr(existing, "status", "")) if existing else ""
|
||||||
|
if status not in ("indexed", "embedding"):
|
||||||
|
unindexedIds.append(fid)
|
||||||
|
|
||||||
|
if not unindexedIds:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
logger.info(f"Ensure-embed: {len(unindexedIds)}/{len(fileIds)} files need indexing")
|
||||||
|
|
||||||
|
from modules.serviceCenter import getService
|
||||||
|
from modules.serviceCenter.context import ServiceCenterContext
|
||||||
|
indexCtx = ServiceCenterContext(
|
||||||
|
user=user, mandateId=mandateId, featureInstanceId=featureInstanceId,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
knowledgeService = getService("knowledge", indexCtx)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Ensure-embed: knowledge service unavailable: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
chatInterface = interfaceDbChat.getInterface(user)
|
||||||
|
indexed = 0
|
||||||
|
for fid in unindexedIds:
|
||||||
|
try:
|
||||||
|
fileInfo = chatInterface.getFileInfo(fid) if chatInterface else None
|
||||||
|
if not fileInfo:
|
||||||
|
continue
|
||||||
|
fileName = fileInfo.get("fileName", "")
|
||||||
|
mimeType = fileInfo.get("mimeType", "")
|
||||||
|
rawBytes = chatInterface.getFileData(fid)
|
||||||
|
if not rawBytes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
extractionService = getService("extraction", indexCtx)
|
||||||
|
extracted = extractionService.extractContentFromBytes(
|
||||||
|
rawBytes, fileName, mimeType, documentId=fid,
|
||||||
|
)
|
||||||
|
contentObjects = [
|
||||||
|
{
|
||||||
|
"contentType": getattr(p, "contentType", "text"),
|
||||||
|
"data": getattr(p, "data", "") or "",
|
||||||
|
"contentObjectId": getattr(p, "contentObjectId", "") or str(uuid.uuid4()),
|
||||||
|
"contextRef": getattr(p, "contextRef", {}) or {},
|
||||||
|
}
|
||||||
|
for p in (extracted.parts or [])
|
||||||
|
if getattr(p, "data", None)
|
||||||
|
]
|
||||||
|
if not contentObjects:
|
||||||
|
continue
|
||||||
|
|
||||||
|
await knowledgeService.indexFile(
|
||||||
|
fileId=fid,
|
||||||
|
fileName=fileName,
|
||||||
|
mimeType=mimeType,
|
||||||
|
userId=user.id if user else "",
|
||||||
|
featureInstanceId=featureInstanceId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
contentObjects=contentObjects,
|
||||||
|
structure=getattr(extracted, "structure", None),
|
||||||
|
)
|
||||||
|
indexed += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Ensure-embed: skipping file {fid}: {e}")
|
||||||
|
|
||||||
|
if indexed:
|
||||||
|
logger.info(f"Ensure-embed: indexed {indexed} file(s) before agent start")
|
||||||
|
return indexed
|
||||||
|
|
||||||
|
|
||||||
async def _deriveWorkflowName(prompt: str, aiService) -> str:
|
async def _deriveWorkflowName(prompt: str, aiService) -> str:
|
||||||
"""Use AI to generate a concise workflow title from the user prompt."""
|
"""Use AI to generate a concise workflow title from the user prompt."""
|
||||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
|
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
|
||||||
|
|
@ -740,20 +892,29 @@ async def _runWorkspaceAgent(
|
||||||
|
|
||||||
priorFileIds = _collectPriorFileIds(chatInterface, workflowId)
|
priorFileIds = _collectPriorFileIds(chatInterface, workflowId)
|
||||||
currentFileIdSet = set(fileIds or [])
|
currentFileIdSet = set(fileIds or [])
|
||||||
mergedFileIds = list(fileIds or [])
|
candidatePriorIds = [pf for pf in priorFileIds if pf not in currentFileIdSet]
|
||||||
for pf in priorFileIds:
|
|
||||||
if pf not in currentFileIdSet:
|
# Embed-first rule: newly attached files are indexed into the knowledge
|
||||||
mergedFileIds.append(pf)
|
# store BEFORE the agent starts, so RAG retrieval works from round 1.
|
||||||
if len(mergedFileIds) > len(fileIds or []):
|
await _ensureFilesIndexed(fileIds or [], user, mandateId, instanceId)
|
||||||
|
|
||||||
|
_cfg = instanceConfig or {}
|
||||||
|
|
||||||
|
if candidatePriorIds:
|
||||||
|
maxPriorFiles = int(_cfg.get("maxPriorContextFiles", DEFAULT_MAX_PRIOR_CONTEXT_FILES))
|
||||||
|
candidatePriorIds = await _selectPriorFilesForContext(
|
||||||
|
candidatePriorIds, prompt, aiObjects, mandateId, instanceId, maxPriorFiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
mergedFileIds = list(fileIds or []) + candidatePriorIds
|
||||||
|
if candidatePriorIds:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Merged {len(mergedFileIds) - len(fileIds or [])} prior file(s) into agent context "
|
f"Merged {len(candidatePriorIds)} prior file(s) into agent context "
|
||||||
f"(total: {len(mergedFileIds)}) for workflow {workflowId}"
|
f"(total: {len(mergedFileIds)}) for workflow {workflowId}"
|
||||||
)
|
)
|
||||||
|
|
||||||
accumulatedText = ""
|
accumulatedText = ""
|
||||||
messagePersisted = False
|
messagePersisted = False
|
||||||
|
|
||||||
_cfg = instanceConfig or {}
|
|
||||||
_toolSet = _cfg.get("toolSet", "core")
|
_toolSet = _cfg.get("toolSet", "core")
|
||||||
_agentCfg = _cfg.get("agentConfig")
|
_agentCfg = _cfg.get("agentConfig")
|
||||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,9 @@ import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
import base64
|
import base64
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
|
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import time
|
import time
|
||||||
|
|
@ -542,7 +544,27 @@ class AiObjects:
|
||||||
else:
|
else:
|
||||||
options.operationType = OperationTypeEnum.EMBEDDING
|
options.operationType = OperationTypeEnum.EMBEDDING
|
||||||
|
|
||||||
combinedText = " ".join(texts[:3])[:500]
|
# Serve known vectors from cache; only unknown texts go to the API.
|
||||||
|
resolvedVectors: Dict[int, List[float]] = {}
|
||||||
|
pendingTexts: List[str] = []
|
||||||
|
pendingPositions: List[int] = []
|
||||||
|
for i, t in enumerate(texts):
|
||||||
|
cachedVector = _embeddingCacheGet(t)
|
||||||
|
if cachedVector is not None:
|
||||||
|
resolvedVectors[i] = cachedVector
|
||||||
|
else:
|
||||||
|
pendingTexts.append(t)
|
||||||
|
pendingPositions.append(i)
|
||||||
|
|
||||||
|
if not pendingTexts:
|
||||||
|
logger.debug(f"Embedding cache hit for all {len(texts)} text(s)")
|
||||||
|
return AiCallResponse(
|
||||||
|
content="", modelName="embedding-cache", priceCHF=0.0,
|
||||||
|
processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=0,
|
||||||
|
metadata={"embeddings": [resolvedVectors[i] for i in range(len(texts))]},
|
||||||
|
)
|
||||||
|
|
||||||
|
combinedText = " ".join(pendingTexts[:3])[:500]
|
||||||
availableModels = modelRegistry.getAvailableModels()
|
availableModels = modelRegistry.getAvailableModels()
|
||||||
|
|
||||||
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
|
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
|
||||||
|
|
@ -575,13 +597,13 @@ class AiObjects:
|
||||||
for attempt, model in enumerate(failoverModelList):
|
for attempt, model in enumerate(failoverModelList):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
|
logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
|
||||||
inputBytes = sum(len(t.encode("utf-8")) for t in texts)
|
inputBytes = sum(len(t.encode("utf-8")) for t in pendingTexts)
|
||||||
startTime = time.time()
|
startTime = time.time()
|
||||||
|
|
||||||
batches = _buildEmbeddingBatches(texts, model.contextLength)
|
batches = _buildEmbeddingBatches(pendingTexts, model.contextLength)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), "
|
f"Embedding: {len(pendingTexts)} texts ({len(resolvedVectors)} cached) -> "
|
||||||
f"model contextLength={model.contextLength}"
|
f"{len(batches)} batch(es), model contextLength={model.contextLength}"
|
||||||
)
|
)
|
||||||
|
|
||||||
allEmbeddings: List[List[float]] = []
|
allEmbeddings: List[List[float]] = []
|
||||||
|
|
@ -606,11 +628,17 @@ class AiObjects:
|
||||||
if totalPriceCHF == 0.0:
|
if totalPriceCHF == 0.0:
|
||||||
totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
|
totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
|
||||||
|
|
||||||
|
for j, position in enumerate(pendingPositions):
|
||||||
|
if j < len(allEmbeddings):
|
||||||
|
resolvedVectors[position] = allEmbeddings[j]
|
||||||
|
_embeddingCachePut(pendingTexts[j], allEmbeddings[j])
|
||||||
|
mergedEmbeddings = [resolvedVectors.get(i, []) for i in range(len(texts))]
|
||||||
|
|
||||||
response = AiCallResponse(
|
response = AiCallResponse(
|
||||||
content="", modelName=model.name, provider=model.connectorType,
|
content="", modelName=model.name, provider=model.connectorType,
|
||||||
priceCHF=totalPriceCHF, processingTime=processingTime,
|
priceCHF=totalPriceCHF, processingTime=processingTime,
|
||||||
bytesSent=inputBytes, bytesReceived=0, errorCount=0,
|
bytesSent=inputBytes, bytesReceived=0, errorCount=0,
|
||||||
metadata={"embeddings": allEmbeddings}
|
metadata={"embeddings": mergedEmbeddings}
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.billingCallback:
|
if self.billingCallback:
|
||||||
|
|
@ -681,6 +709,28 @@ class AiObjects:
|
||||||
_CHARS_PER_TOKEN = 4
|
_CHARS_PER_TOKEN = 4
|
||||||
_SAFETY_MARGIN = 0.90
|
_SAFETY_MARGIN = 0.90
|
||||||
|
|
||||||
|
# In-process cache for embedding vectors. Identical texts (e.g. the same user
|
||||||
|
# prompt embedded once for prior-file selection and once for RAG context
|
||||||
|
# building) hit the cache instead of paying for a second API call.
|
||||||
|
_EMBEDDING_CACHE_MAX_ENTRIES = 256
|
||||||
|
_embeddingCache: OrderedDict = OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
|
def _embeddingCacheGet(text: str) -> Optional[List[float]]:
|
||||||
|
key = hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||||
|
vector = _embeddingCache.get(key)
|
||||||
|
if vector is not None:
|
||||||
|
_embeddingCache.move_to_end(key)
|
||||||
|
return vector
|
||||||
|
|
||||||
|
|
||||||
|
def _embeddingCachePut(text: str, vector: List[float]) -> None:
|
||||||
|
key = hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||||
|
_embeddingCache[key] = vector
|
||||||
|
_embeddingCache.move_to_end(key)
|
||||||
|
while len(_embeddingCache) > _EMBEDDING_CACHE_MAX_ENTRIES:
|
||||||
|
_embeddingCache.popitem(last=False)
|
||||||
|
|
||||||
|
|
||||||
def _estimateTokens(text: str) -> int:
|
def _estimateTokens(text: str) -> int:
|
||||||
"""Rough token estimate: 1 token ~ 4 characters."""
|
"""Rough token estimate: 1 token ~ 4 characters."""
|
||||||
|
|
@ -691,9 +741,7 @@ def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[st
|
||||||
"""Split a list of texts into batches whose total estimated token count
|
"""Split a list of texts into batches whose total estimated token count
|
||||||
stays within the model's contextLength (with safety margin).
|
stays within the model's contextLength (with safety margin).
|
||||||
|
|
||||||
Each individual text is assumed to already be within limits (enforced by
|
Texts that individually exceed the per-input limit are truncated to fit.
|
||||||
the chunking layer). If a single text exceeds the budget, it is placed
|
|
||||||
in its own batch as a last resort.
|
|
||||||
"""
|
"""
|
||||||
if not texts:
|
if not texts:
|
||||||
return []
|
return []
|
||||||
|
|
@ -701,11 +749,21 @@ def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[st
|
||||||
return [texts]
|
return [texts]
|
||||||
|
|
||||||
maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN)
|
maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN)
|
||||||
|
maxCharsPerInput = maxTokensPerBatch * _CHARS_PER_TOKEN
|
||||||
batches: List[List[str]] = []
|
batches: List[List[str]] = []
|
||||||
currentBatch: List[str] = []
|
currentBatch: List[str] = []
|
||||||
currentTokens = 0
|
currentTokens = 0
|
||||||
|
|
||||||
for text in texts:
|
for text in texts:
|
||||||
|
if len(text) > maxCharsPerInput:
|
||||||
|
# API hard limit per input. File content never hits this (the
|
||||||
|
# chunking layer splits at ~400 tokens); only oversized search
|
||||||
|
# queries can, where truncation is semantically acceptable.
|
||||||
|
logger.warning(
|
||||||
|
f"Embedding input truncated from {len(text)} to {maxCharsPerInput} chars "
|
||||||
|
f"(model input limit {contextLength} tokens)"
|
||||||
|
)
|
||||||
|
text = text[:maxCharsPerInput]
|
||||||
textTokens = _estimateTokens(text)
|
textTokens = _estimateTokens(text)
|
||||||
if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch:
|
if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch:
|
||||||
batches.append(currentBatch)
|
batches.append(currentBatch)
|
||||||
|
|
|
||||||
|
|
@ -3112,8 +3112,12 @@ class AppObjects:
|
||||||
|
|
||||||
# Token methods
|
# Token methods
|
||||||
|
|
||||||
def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None:
|
def saveAccessToken(self, token: Token, replace_existing: bool = False) -> None:
|
||||||
"""Save an access token for the current user (must NOT have connectionId)"""
|
"""Save an access token for the current user (must NOT have connectionId).
|
||||||
|
|
||||||
|
Multi-session: replace_existing=False (default) keeps existing sessions alive.
|
||||||
|
Only set replace_existing=True for explicit single-session scenarios.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Validate that this is NOT a connection token
|
# Validate that this is NOT a connection token
|
||||||
if token.connectionId:
|
if token.connectionId:
|
||||||
|
|
|
||||||
185
modules/routes/routeAdminSessions.py
Normal file
185
modules/routes/routeAdminSessions.py
Normal file
|
|
@ -0,0 +1,185 @@
|
||||||
|
# Copyright (c) 2026 PowerOn AG
|
||||||
|
# All rights reserved.
|
||||||
|
"""
|
||||||
|
Admin endpoints for session and trusted device management.
|
||||||
|
|
||||||
|
Allows mandate-admins and platform-admins to view and revoke active sessions
|
||||||
|
and trusted devices for users under their jurisdiction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status, Depends, Request, Query
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from modules.auth import limiter, getCurrentUser
|
||||||
|
from modules.datamodels.datamodelUam import User
|
||||||
|
from modules.datamodels.datamodelSecurity import Token, TokenPurpose, TokenStatus, TrustedDevice
|
||||||
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
|
from modules.shared.timeUtils import getUtcTimestamp
|
||||||
|
from modules.shared.i18nRegistry import apiRouteContext
|
||||||
|
|
||||||
|
routeApiMsg = apiRouteContext("routeAdminSessions")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/api/admin/sessions",
|
||||||
|
tags=["Admin Sessions"],
|
||||||
|
responses={404: {"description": "Not found"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _requireAdmin(currentUser: User) -> None:
|
||||||
|
"""Ensure the caller is a platform admin or sysAdmin."""
|
||||||
|
if not (getattr(currentUser, "isPlatformAdmin", False) or getattr(currentUser, "isSysAdmin", False)):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=routeApiMsg("Only platform admins can manage sessions"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
def listSessions(
|
||||||
|
request: Request,
|
||||||
|
userId: str = Query(..., description="User ID whose sessions to list"),
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""List active auth sessions for a user."""
|
||||||
|
_requireAdmin(currentUser)
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
|
||||||
|
tokens = rootInterface.db.getRecordset(
|
||||||
|
Token,
|
||||||
|
recordFilter={
|
||||||
|
"userId": userId,
|
||||||
|
"tokenPurpose": TokenPurpose.AUTH_SESSION.value,
|
||||||
|
"status": TokenStatus.ACTIVE.value,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
now = getUtcTimestamp()
|
||||||
|
result = []
|
||||||
|
for t in tokens:
|
||||||
|
expiresAt = t.get("expiresAt", 0)
|
||||||
|
if expiresAt < now:
|
||||||
|
continue
|
||||||
|
result.append({
|
||||||
|
"sessionId": t.get("sessionId"),
|
||||||
|
"tokenId": t.get("id"),
|
||||||
|
"authority": t.get("authority"),
|
||||||
|
"createdAt": t.get("sysCreatedAt"),
|
||||||
|
"expiresAt": expiresAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{sessionId}")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
def revokeSession(
|
||||||
|
request: Request,
|
||||||
|
sessionId: str,
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Revoke a single session by sessionId."""
|
||||||
|
_requireAdmin(currentUser)
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
|
||||||
|
tokens = rootInterface.db.getRecordset(
|
||||||
|
Token,
|
||||||
|
recordFilter={"sessionId": sessionId, "tokenPurpose": TokenPurpose.AUTH_SESSION.value},
|
||||||
|
)
|
||||||
|
count = 0
|
||||||
|
for t in tokens:
|
||||||
|
rootInterface.db.recordDelete(Token, t["id"])
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if count == 0:
|
||||||
|
raise HTTPException(status_code=404, detail=routeApiMsg("Session not found"))
|
||||||
|
|
||||||
|
logger.info(f"Admin {currentUser.username} revoked session {sessionId} ({count} token(s))")
|
||||||
|
return {"revoked": count, "sessionId": sessionId}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
def revokeAllSessions(
|
||||||
|
request: Request,
|
||||||
|
userId: str = Query(..., description="User ID whose sessions to revoke"),
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Revoke ALL active sessions for a user (force logout everywhere)."""
|
||||||
|
_requireAdmin(currentUser)
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
|
||||||
|
tokens = rootInterface.db.getRecordset(
|
||||||
|
Token,
|
||||||
|
recordFilter={
|
||||||
|
"userId": userId,
|
||||||
|
"tokenPurpose": TokenPurpose.AUTH_SESSION.value,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
count = 0
|
||||||
|
for t in tokens:
|
||||||
|
rootInterface.db.recordDelete(Token, t["id"])
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
logger.info(f"Admin {currentUser.username} revoked all sessions for userId={userId} ({count} token(s))")
|
||||||
|
return {"revoked": count, "userId": userId}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Trusted Devices ---
|
||||||
|
|
||||||
|
trustedDeviceRouter = APIRouter(
|
||||||
|
prefix="/api/admin/trusted-devices",
|
||||||
|
tags=["Admin Sessions"],
|
||||||
|
responses={404: {"description": "Not found"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@trustedDeviceRouter.get("")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
def listTrustedDevices(
|
||||||
|
request: Request,
|
||||||
|
userId: str = Query(..., description="User ID whose trusted devices to list"),
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""List trusted devices for a user."""
|
||||||
|
_requireAdmin(currentUser)
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
|
||||||
|
devices = rootInterface.db.getRecordset(
|
||||||
|
TrustedDevice, recordFilter={"userId": userId}
|
||||||
|
)
|
||||||
|
|
||||||
|
now = getUtcTimestamp()
|
||||||
|
result = []
|
||||||
|
for d in devices:
|
||||||
|
result.append({
|
||||||
|
"id": d.get("id", "")[:8] + "...",
|
||||||
|
"trustedUntil": d.get("trustedUntil"),
|
||||||
|
"isExpired": d.get("trustedUntil", 0) < now,
|
||||||
|
"userAgent": d.get("userAgent"),
|
||||||
|
"ipAddress": d.get("ipAddress"),
|
||||||
|
"createdAt": d.get("createdAt"),
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@trustedDeviceRouter.delete("")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
def revokeAllTrustedDevices(
|
||||||
|
request: Request,
|
||||||
|
userId: str = Query(..., description="User ID whose trusted devices to revoke"),
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Revoke ALL trusted devices for a user (force MFA on next login)."""
|
||||||
|
_requireAdmin(currentUser)
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
|
||||||
|
from modules.auth.trustedDeviceService import revokeTrustedDevices
|
||||||
|
count = revokeTrustedDevices(userId, rootInterface.db)
|
||||||
|
|
||||||
|
logger.info(f"Admin {currentUser.username} revoked all trusted devices for userId={userId} ({count})")
|
||||||
|
return {"revoked": count, "userId": userId}
|
||||||
|
|
@ -232,6 +232,13 @@ def mfaVerify(
|
||||||
|
|
||||||
logger.info("MFA verify successful for user %s", username)
|
logger.info("MFA verify successful for user %s", username)
|
||||||
|
|
||||||
|
# Mark device as trusted so MFA is skipped on next login from this device
|
||||||
|
try:
|
||||||
|
from modules.auth.trustedDeviceService import createTrustedDevice
|
||||||
|
createTrustedDevice(userId, request, response, rootInterface.db)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to create trusted device after MFA verify: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from modules.dbHelpers.auditLogger import audit_logger
|
from modules.dbHelpers.auditLogger import audit_logger
|
||||||
audit_logger.logUserAccess(
|
audit_logger.logUserAccess(
|
||||||
|
|
|
||||||
|
|
@ -256,6 +256,7 @@ def login(
|
||||||
|
|
||||||
# --- MFA gate --------------------------------------------------------
|
# --- MFA gate --------------------------------------------------------
|
||||||
from modules.auth.mfaService import isMfaRequired as _isMfaRequired
|
from modules.auth.mfaService import isMfaRequired as _isMfaRequired
|
||||||
|
from modules.auth.trustedDeviceService import isTrustedDevice as _isTrustedDevice
|
||||||
from modules.routes.routeMfa import createMfaPendingToken
|
from modules.routes.routeMfa import createMfaPendingToken
|
||||||
|
|
||||||
userRecord = rootInterface._getUserForAuthentication(user.username)
|
userRecord = rootInterface._getUserForAuthentication(user.username)
|
||||||
|
|
@ -273,7 +274,14 @@ def login(
|
||||||
mfaRequired = _isMfaRequired(user, userMandates=userMandates, mandates=mandateObjs)
|
mfaRequired = _isMfaRequired(user, userMandates=userMandates, mandates=mandateObjs)
|
||||||
hasMfaSetup = bool(userRecord and userRecord.get("mfaSecret") and getattr(user, "mfaEnabled", False))
|
hasMfaSetup = bool(userRecord and userRecord.get("mfaSecret") and getattr(user, "mfaEnabled", False))
|
||||||
|
|
||||||
|
# Trusted device: skip MFA if the device was previously verified
|
||||||
|
_deviceTrusted = False
|
||||||
if mfaRequired or hasMfaSetup:
|
if mfaRequired or hasMfaSetup:
|
||||||
|
_deviceTrusted = _isTrustedDevice(request, str(user.id), rootInterface.db)
|
||||||
|
if _deviceTrusted:
|
||||||
|
logger.info(f"MFA skipped for user {user.username} (trusted device)")
|
||||||
|
|
||||||
|
if (mfaRequired or hasMfaSetup) and not _deviceTrusted:
|
||||||
_sid = str(uuid.uuid4())
|
_sid = str(uuid.uuid4())
|
||||||
pendingToken = createMfaPendingToken(
|
pendingToken = createMfaPendingToken(
|
||||||
userId=str(user.id),
|
userId=str(user.id),
|
||||||
|
|
@ -659,31 +667,46 @@ def refresh_token(
|
||||||
logger.error(f"Failed to get user from database: {str(e)}")
|
logger.error(f"Failed to get user from database: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=routeApiMsg("Failed to validate user"))
|
raise HTTPException(status_code=500, detail=routeApiMsg("Failed to validate user"))
|
||||||
|
|
||||||
|
# Preserve sessionId from the refresh token so the session stays grouped
|
||||||
|
sessionId = payload.get("sid") or str(uuid.uuid4())
|
||||||
|
|
||||||
# Create new token data
|
# Create new token data
|
||||||
# MULTI-TENANT: Token does NOT contain mandateId anymore
|
newJti = str(uuid.uuid4())
|
||||||
token_data = {
|
token_data = {
|
||||||
"sub": current_user.username,
|
"sub": current_user.username,
|
||||||
"userId": str(current_user.id),
|
"userId": str(current_user.id),
|
||||||
"authenticationAuthority": current_user.authenticationAuthority
|
"authenticationAuthority": current_user.authenticationAuthority,
|
||||||
# NO mandateId in token
|
"jti": newJti,
|
||||||
|
"sid": sessionId,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create new access token + set cookie
|
# Create new access token + set cookie
|
||||||
access_token, _expires = createAccessToken(token_data)
|
access_token, accessExpires = createAccessToken(token_data)
|
||||||
setAccessTokenCookie(response, access_token)
|
setAccessTokenCookie(response, access_token)
|
||||||
|
|
||||||
# Get expiration time
|
# Persist the new token in DB so _getUserBase() accepts it
|
||||||
|
authority = current_user.authenticationAuthority
|
||||||
|
if isinstance(authority, str):
|
||||||
|
authority = AuthAuthority(authority)
|
||||||
|
dbToken = Token(
|
||||||
|
id=newJti,
|
||||||
|
userId=str(current_user.id),
|
||||||
|
authority=authority,
|
||||||
|
tokenAccess=access_token,
|
||||||
|
tokenPurpose=TokenPurpose.AUTH_SESSION,
|
||||||
|
expiresAt=accessExpires.timestamp(),
|
||||||
|
sessionId=sessionId,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
|
userInterface = getInterface(current_user)
|
||||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
userInterface.saveAccessToken(dbToken)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to decode new access token: {str(e)}")
|
logger.warning(f"Failed to persist refreshed token in DB: {e}")
|
||||||
raise HTTPException(status_code=500, detail=routeApiMsg("Failed to create new token"))
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "token_refresh_success",
|
"type": "token_refresh_success",
|
||||||
"message": "Token refreshed successfully",
|
"message": "Token refreshed successfully",
|
||||||
"expires_at": expires_at.isoformat()
|
"expires_at": accessExpires.isoformat()
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
|
|
|
||||||
|
|
@ -638,7 +638,8 @@ class KnowledgeService:
|
||||||
Returns:
|
Returns:
|
||||||
Formatted context string for injection into the agent's system prompt.
|
Formatted context string for injection into the agent's system prompt.
|
||||||
"""
|
"""
|
||||||
queryVector = await self._embedSingle(currentPrompt)
|
queryText = _extractUserQuery(currentPrompt)
|
||||||
|
queryVector = await self._embedSingle(queryText) if queryText else []
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"buildAgentContext.start userId=%s featureInstanceId=%s mandateId=%s isSysAdmin=%s prompt=%r",
|
"buildAgentContext.start userId=%s featureInstanceId=%s mandateId=%s isSysAdmin=%s prompt=%r",
|
||||||
userId, featureInstanceId, mandateId, isSysAdmin, (currentPrompt or "")[:120],
|
userId, featureInstanceId, mandateId, isSysAdmin, (currentPrompt or "")[:120],
|
||||||
|
|
@ -960,6 +961,29 @@ class KnowledgeService:
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
# Markers added by the prompt enrichment layers (_enrichPromptWithFiles in
|
||||||
|
# mainServiceAgent, data source sections in routeFeatureWorkspace). Used to
|
||||||
|
# isolate the user's own words for the semantic search query.
|
||||||
|
_USER_REQUEST_MARKER = "\n\nUser request: "
|
||||||
|
_PROMPT_SECTION_MARKERS = ("\n\n[Active Data Sources]\n", "\n\n[Attached Feature Data Sources]\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _extractUserQuery(prompt: str) -> str:
|
||||||
|
"""Isolate the user's actual question from an enriched agent prompt.
|
||||||
|
|
||||||
|
Enriched prompts wrap the user request in file metadata headers and
|
||||||
|
data-source sections. Only the user's own words form a meaningful semantic
|
||||||
|
search query - embedding the metadata would dilute the vector and can
|
||||||
|
exceed the embedding model's input limit.
|
||||||
|
"""
|
||||||
|
if not prompt:
|
||||||
|
return ""
|
||||||
|
text = prompt.rsplit(_USER_REQUEST_MARKER, 1)[-1]
|
||||||
|
for marker in _PROMPT_SECTION_MARKERS:
|
||||||
|
text = text.split(marker, 1)[0]
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
def _estimateTokens(text: str) -> int:
|
def _estimateTokens(text: str) -> int:
|
||||||
"""Estimate token count using character-based heuristic (1 token ~ 4 chars)."""
|
"""Estimate token count using character-based heuristic (1 token ~ 4 chars)."""
|
||||||
return max(1, len(text) // CHARS_PER_TOKEN)
|
return max(1, len(text) // CHARS_PER_TOKEN)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue