PEK update

Ida Dittrich, 2025-12-15 07:13:17 +01:00
commit 12dd5aaea6 (parent ad14e272f3)
8 changed files with 400 additions and 40 deletions


@@ -78,3 +78,28 @@ class PaginatedResponse(BaseModel, Generic[T]):
model_config = ConfigDict(arbitrary_types_allowed=True)
def normalize_pagination_dict(pagination_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalize pagination dictionary to handle frontend variations.
Moves top-level "search" field into filters if present.
Args:
pagination_dict: Raw pagination dictionary from frontend
Returns:
Normalized pagination dictionary ready for PaginationParams parsing
"""
if not pagination_dict:
return pagination_dict
# Create a copy to avoid modifying the original
normalized = dict(pagination_dict)
# Move top-level "search" into filters if present
if "search" in normalized:
if "filters" not in normalized or normalized["filters"] is None:
normalized["filters"] = {}
normalized["filters"]["search"] = normalized.pop("search")
return normalized
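A quick sketch of the intended behavior, with example values that are ours rather than the commit's:
# Hypothetical usage - promotes a top-level "search" into filters:
raw = {"page": 1, "pageSize": 10, "search": "report"}
normalized = normalize_pagination_dict(raw)
assert normalized == {"page": 1, "pageSize": 10, "filters": {"search": "report"}}
assert "search" in raw  # the input dict itself is not mutated in this case
# Caveat: dict(pagination_dict) is a shallow copy, so a pre-existing "filters"
# dict would be shared between input and output and mutated in place.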


@@ -384,12 +384,14 @@ class ChatObjects:
matches = True
# Handle general search across text fields
if "search" in filters:
search_term = str(filters["search"]).lower()
if "search" in filters and filters["search"] is not None:
search_term = str(filters["search"]).strip().lower()
if search_term:
# Search in all string fields
found = False
for key, value in record.items():
if value is None:
continue
if isinstance(value, str) and search_term in value.lower():
found = True
break
@@ -404,17 +406,39 @@ class ChatObjects:
if field_name == "search":
continue # Already handled above
# Skip None filter values (empty filter strings are handled in the string branch below)
if filter_value is None:
continue
# If field doesn't exist in record, reject this record for this filter
if field_name not in record:
matches = False
break
record_value = record.get(field_name)
- # Handle simple value (equals operator)
+ # Handle simple value - default to "contains" for strings, "equals" for other types
if not isinstance(filter_value, dict):
- if record_value != filter_value:
# Reject the record if the field value is None
if record_value is None:
matches = False
break
# For string fields, default to "contains" for more intuitive filtering
if isinstance(record_value, str) and isinstance(filter_value, str):
# Skip empty filter strings
filter_str = str(filter_value).strip().lower()
if not filter_str:
continue
record_str = record_value.lower()
if filter_str not in record_str:
matches = False
break
else:
# For non-string fields, use exact match
if record_value != filter_value:
matches = False
break
continue
# Handle filter with operator

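To summarize the new default semantics: string filters now match as case-insensitive substrings, while non-string filters still require exact equality. A standalone mirror of the branch above (the helper name is ours, not in the commit):

def matches_simple(record_value, filter_value):
    # Mirrors the simple-value branch: "contains" for str/str, equality otherwise.
    if record_value is None:
        return False
    if isinstance(record_value, str) and isinstance(filter_value, str):
        needle = filter_value.strip().lower()
        return True if not needle else needle in record_value.lower()  # empty filter is ignored
    return record_value == filter_value

assert matches_simple("Quarterly Report", "report")             # substring, case-insensitive
assert not matches_simple("Quarterly Report", "budget")
assert matches_simple(42, 42) and not matches_simple(42, "42")  # non-strings: exact match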

@@ -308,12 +308,14 @@ class ComponentObjects:
matches = True
# Handle general search across text fields
if "search" in filters:
search_term = str(filters["search"]).lower()
if "search" in filters and filters["search"] is not None:
search_term = str(filters["search"]).strip().lower()
if search_term:
# Search in all string fields
found = False
for key, value in record.items():
if value is None:
continue
if isinstance(value, str) and search_term in value.lower():
found = True
break
@@ -322,23 +324,72 @@ class ComponentObjects:
break
if not found:
matches = False
# Continue checking other filters, but this record won't match
# Handle field-specific filters
for field_name, filter_value in filters.items():
if field_name == "search":
continue # Already handled above
# Skip None filter values (empty filter strings are handled in the string branch below)
if filter_value is None:
continue
# If field doesn't exist in record, reject this record for this filter
if field_name not in record:
matches = False
break
record_value = record.get(field_name)
- # Handle simple value (equals operator)
+ # Handle simple value - default to "contains" for strings, "equals" for other types
if not isinstance(filter_value, dict):
- if record_value != filter_value:
# Reject the record if the field value is None
if record_value is None:
matches = False
break
# Special handling for fileSize field - parse formatted size strings
if field_name == "fileSize" and isinstance(filter_value, str):
try:
# Parse formatted size string (e.g., "2.13 MB", "1.5 GB", "500 KB")
filter_size_bytes = self._parse_size_string(filter_value)
if filter_size_bytes is not None:
# Compare as integers (bytes)
if isinstance(record_value, (int, float)):
# Allow small tolerance for rounding differences (5% or 50KB, whichever is smaller)
# This accounts for formatting differences but prevents matching very different sizes
tolerance = min(filter_size_bytes * 0.05, 50 * 1024)
if abs(record_value - filter_size_bytes) > tolerance:
matches = False
break
else:
matches = False
break
else:
# If parsing fails, try string contains match
if str(record_value) not in filter_value:
matches = False
break
except Exception:
# If parsing fails, skip this filter
continue
# For string fields, default to "contains" for more intuitive filtering
elif isinstance(record_value, str) and isinstance(filter_value, str):
# Skip empty filter strings
filter_str = str(filter_value).strip().lower()
if not filter_str:
continue
record_str = record_value.lower()
if filter_str not in record_str:
matches = False
break
else:
# For non-string fields, use exact match
if record_value != filter_value:
matches = False
break
continue
# Handle filter with operator
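Worked example of the fileSize tolerance (numbers are ours): "2.13 MB" parses to 2,233,466 bytes; 5% of that is about 111,673 bytes, so the 50 KB cap wins and records within ±51,200 bytes match.
# Hypothetical check mirroring the tolerance rule above:
filter_size_bytes = int(2.13 * 1024 * 1024)             # "2.13 MB" -> 2_233_466 bytes
tolerance = min(filter_size_bytes * 0.05, 50 * 1024)    # -> 51_200 (50 KB cap applies)
assert abs(2_250_000 - filter_size_bytes) <= tolerance  # ~16 KB off: accepted
assert abs(2_400_000 - filter_size_bytes) > tolerance   # ~167 KB off: rejected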
@@ -491,6 +542,49 @@ class ComponentObjects:
def getInitialId(self, model_class: type) -> Optional[str]:
"""Returns the initial ID for a table."""
return self.db.getInitialId(model_class)
def _parse_size_string(self, size_str: str) -> Optional[int]:
"""
Parse a formatted size string (e.g., "2.13 MB", "1.5 GB") to bytes.
Args:
size_str: Formatted size string like "2.13 MB", "1.5 GB", "500 KB"
Returns:
Size in bytes, or None if parsing fails
"""
try:
size_str = size_str.strip().upper()
# Remove common separators and spaces
size_str = size_str.replace(",", "").replace(" ", "")
# Extract number and unit - handle both "MB" and "M" formats
import re
# Match: number (with optional decimal) followed by optional unit (K/M/G/T with optional B)
match = re.match(r"^([\d.]+)([KMGT]?B?)$", size_str)
if not match:
return None
number = float(match.group(1))
unit = match.group(2) or "B"
# Normalize unit (handle "M" as "MB", "K" as "KB", etc.)
if len(unit) == 1 and unit in "KMGT":
unit = unit + "B"
# Convert to bytes
multipliers = {
"B": 1,
"KB": 1024,
"MB": 1024 * 1024,
"GB": 1024 * 1024 * 1024,
"TB": 1024 * 1024 * 1024 * 1024,
}
multiplier = multipliers.get(unit, 1)
return int(number * multiplier)
except Exception:
return None
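Expected parses, following the regex and multiplier table above (co stands for a ComponentObjects instance; its construction is not shown in this diff, so the no-arg call below is an assumption):
co = ComponentObjects()  # assumption: the constructor signature is not shown in this diff
assert co._parse_size_string("500 KB") == 500 * 1024        # 512_000 bytes
assert co._parse_size_string("1.5G") == int(1.5 * 1024**3)  # bare "G" normalized to "GB"
assert co._parse_size_string("1,024") == 1024               # comma stripped; no unit means bytes
assert co._parse_size_string("huge") is None                # unparseable input -> None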


@@ -8,17 +8,20 @@ SECURITY NOTE:
- This prevents security vulnerabilities where admin users could see other users' connections
"""
- from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
+ from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
from typing import List, Dict, Any, Optional
from fastapi import status
import logging
import json
import math
from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority, ConnectionStatus
from modules.datamodels.datamodelSecurity import Token
from modules.auth import getCurrentUser, limiter
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.interfaces.interfaceDbAppObjects import getInterface
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
from modules.interfaces.interfaceDbComponentObjects import ComponentObjects
# Configure logger
logger = logging.getLogger(__name__)
@@ -87,20 +90,44 @@ router = APIRouter(
responses={404: {"description": "Not found"}}
)
@router.get("/", response_model=List[UserConnection])
@router.get("/", response_model=PaginatedResponse[UserConnection])
@limiter.limit("30/minute")
async def get_connections(
request: Request,
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
currentUser: User = Depends(getCurrentUser)
- ) -> List[UserConnection]:
- """Get all connections for the current user
+ ) -> PaginatedResponse[UserConnection]:
+ """Get connections for the current user with optional pagination, sorting, and filtering.
SECURITY: This endpoint is secure - users can only see their own connections.
Automatically refreshes expired OAuth tokens in the background.
Query Parameters:
- pagination: JSON-encoded PaginationParams object, or None for no pagination
Examples:
- GET /api/connections/ (no pagination - returns all items)
- GET /api/connections/?pagination={"page":1,"pageSize":10,"sort":[]}
- GET /api/connections/?pagination={"page":1,"pageSize":10,"filters":{"status":"active"}}
"""
try:
interface = getInterface(currentUser)
# Parse pagination parameter
paginationParams = None
if pagination:
try:
paginationDict = json.loads(pagination)
if paginationDict:
# Normalize pagination dict (handles top-level "search" field)
paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=400,
detail=f"Invalid pagination parameter: {str(e)}"
)
# SECURITY FIX: All users (including admins) can only see their own connections
# This prevents admin from seeing other users' connections and causing confusion
connections = interface.getUserConnections(currentUser.id)
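For illustration, a client call against this endpoint might look as follows (a sketch using the requests library; host and token are placeholders). Note that a top-level "search" key is accepted because of normalize_pagination_dict:
import json, requests  # hypothetical client-side sketch, not part of the commit

pagination = {"page": 1, "pageSize": 10, "filters": {"status": "active"}, "search": "git"}
resp = requests.get(
    "https://example.invalid/api/connections/",     # placeholder host
    params={"pagination": json.dumps(pagination)},  # JSON goes URL-encoded into the query
    headers={"Authorization": "Bearer <token>"},    # placeholder credentials
)
data = resp.json()
# data["items"] holds the connections; data["pagination"] holds the metadata
# (pagination is None when the query parameter is omitted).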
@@ -117,33 +144,111 @@ async def get_connections(
logger.warning(f"Silent token refresh failed for user {currentUser.id}: {str(e)}")
# Continue with original connections even if refresh fails
- # Enhance each connection with token status information
- enhanced_connections = []
+ # Enhance each connection with token status information and convert to dict
+ enhanced_connections_dict = []
for connection in connections:
# Get token status for this connection
tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
- # Create enhanced connection with token status
- enhanced_connection = UserConnection(
- id=connection.id,
- userId=connection.userId,
- authority=connection.authority,
- externalId=connection.externalId,
- externalUsername=connection.externalUsername,
- externalEmail=connection.externalEmail,
- status=connection.status,
- connectedAt=connection.connectedAt,
- lastChecked=connection.lastChecked,
- expiresAt=connection.expiresAt,
- tokenStatus=tokenStatus,
- tokenExpiresAt=tokenExpiresAt
# Convert to dict for filtering/sorting
connection_dict = {
"id": connection.id,
"userId": connection.userId,
"authority": connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority),
"externalId": connection.externalId,
"externalUsername": connection.externalUsername or "",
"externalEmail": connection.externalEmail, # Keep None instead of converting to empty string
"status": connection.status.value if hasattr(connection.status, 'value') else str(connection.status),
"connectedAt": connection.connectedAt,
"lastChecked": connection.lastChecked,
"expiresAt": connection.expiresAt,
"tokenStatus": tokenStatus,
"tokenExpiresAt": tokenExpiresAt
}
enhanced_connections_dict.append(connection_dict)
# If no pagination requested, return all items
if paginationParams is None:
# Convert back to UserConnection objects (enum strings are already in dict)
items = []
for conn_dict in enhanced_connections_dict:
conn_dict_copy = dict(conn_dict)
if "authority" in conn_dict_copy and isinstance(conn_dict_copy["authority"], str):
try:
conn_dict_copy["authority"] = AuthAuthority(conn_dict_copy["authority"])
except ValueError:
pass
if "status" in conn_dict_copy and isinstance(conn_dict_copy["status"], str):
try:
conn_dict_copy["status"] = ConnectionStatus(conn_dict_copy["status"])
except ValueError:
pass
items.append(UserConnection(**conn_dict_copy))
return PaginatedResponse(
items=items,
pagination=None
)
- enhanced_connections.append(enhanced_connection)
- return enhanced_connections
# Apply filtering if provided
if paginationParams.filters:
component_interface = ComponentObjects()
component_interface.setUserContext(currentUser)
enhanced_connections_dict = component_interface._applyFilters(
enhanced_connections_dict,
paginationParams.filters
)
# Apply sorting if provided
if paginationParams.sort:
component_interface = ComponentObjects()
component_interface.setUserContext(currentUser)
enhanced_connections_dict = component_interface._applySorting(
enhanced_connections_dict,
paginationParams.sort
)
# Count total items after filters
totalItems = len(enhanced_connections_dict)
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
# Apply pagination (skip/limit)
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
endIdx = startIdx + paginationParams.pageSize
paged_connections = enhanced_connections_dict[startIdx:endIdx]
# Convert back to UserConnection objects (convert enum strings back to enums)
items = []
for conn_dict in paged_connections:
# Convert enum strings back to enum objects
conn_dict_copy = dict(conn_dict)
if "authority" in conn_dict_copy and isinstance(conn_dict_copy["authority"], str):
try:
conn_dict_copy["authority"] = AuthAuthority(conn_dict_copy["authority"])
except ValueError:
pass # Keep as string if invalid
if "status" in conn_dict_copy and isinstance(conn_dict_copy["status"], str):
try:
conn_dict_copy["status"] = ConnectionStatus(conn_dict_copy["status"])
except ValueError:
pass # Keep as string if invalid
items.append(UserConnection(**conn_dict_copy))
return PaginatedResponse(
items=items,
pagination=PaginationMetadata(
currentPage=paginationParams.page,
pageSize=paginationParams.pageSize,
totalItems=totalItems,
totalPages=totalPages,
sort=paginationParams.sort,
filters=paginationParams.filters
)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting connections: {str(e)}")
logger.error(f"Error getting connections: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get connections: {str(e)}"

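The paging arithmetic above, taken in isolation with hypothetical numbers: 23 filtered items at pageSize 10 give totalPages = ceil(23/10) = 3, and page 3 slices indices 20..29, returning the last 3 items.
import math  # standalone check of the slice math, example numbers only
totalItems, page, pageSize = 23, 3, 10
totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0  # -> 3
startIdx = (page - 1) * pageSize  # -> 20
endIdx = startIdx + pageSize      # -> 30; slicing past the end is safe in Python
assert list(range(totalItems))[startIdx:endIdx] == [20, 21, 22]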

@@ -12,7 +12,7 @@ import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObj
from modules.datamodels.datamodelFiles import FileItem, FilePreview
from modules.shared.attributeUtils import getModelAttributeDefinitions
from modules.datamodels.datamodelUam import User
- from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
+ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
# Configure logger
logger = logging.getLogger(__name__)
@@ -57,7 +57,10 @@ async def get_files(
if pagination:
try:
paginationDict = json.loads(pagination)
- paginationParams = PaginationParams(**paginationDict) if paginationDict else None
+ if paginationDict:
+ # Normalize pagination dict (handles top-level "search" field)
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=400,


@@ -11,7 +11,7 @@ from modules.auth import limiter, getCurrentUser
import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObjects
from modules.datamodels.datamodelUtils import Prompt
from modules.datamodels.datamodelUam import User
- from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
+ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
# Configure logger
logger = logging.getLogger(__name__)
@@ -46,7 +46,10 @@ async def get_prompts(
if pagination:
try:
paginationDict = json.loads(pagination)
- paginationParams = PaginationParams(**paginationDict) if paginationDict else None
+ if paginationDict:
+ # Normalize pagination dict (handles top-level "search" field)
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=400,


@@ -26,7 +26,7 @@ from modules.datamodels.datamodelChat import (
)
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse
from modules.datamodels.datamodelUam import User
- from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
+ from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
# Configure logger
@@ -69,7 +69,10 @@ async def get_workflows(
if pagination:
try:
paginationDict = json.loads(pagination)
- paginationParams = PaginationParams(**paginationDict) if paginationDict else None
+ if paginationDict:
+ # Normalize pagination dict (handles top-level "search" field)
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=400,
@@ -262,7 +265,10 @@ async def get_workflow_logs(
if pagination:
try:
paginationDict = json.loads(pagination)
- paginationParams = PaginationParams(**paginationDict) if paginationDict else None
+ if paginationDict:
+ # Normalize pagination dict (handles top-level "search" field)
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=400,
@@ -350,7 +356,10 @@ async def get_workflow_messages(
if pagination:
try:
paginationDict = json.loads(pagination)
- paginationParams = PaginationParams(**paginationDict) if paginationDict else None
+ if paginationDict:
+ # Normalize pagination dict (handles top-level "search" field)
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
status_code=400,

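The same parse-and-normalize block now repeats in the connections, files, prompts, and workflows routers; a shared helper along these lines (hypothetical, not part of this commit) would collapse the duplication:

import json
from typing import Optional
from fastapi import HTTPException
from modules.datamodels.datamodelPagination import PaginationParams, normalize_pagination_dict

def parsePaginationQuery(pagination: Optional[str]) -> Optional[PaginationParams]:
    # Hypothetical helper; the name and placement are ours.
    if not pagination:
        return None
    try:
        paginationDict = json.loads(pagination)
        if not paginationDict:
            return None
        paginationDict = normalize_pagination_dict(paginationDict)
        return PaginationParams(**paginationDict)
    except (json.JSONDecodeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")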

@@ -227,11 +227,28 @@ class WorkflowManager:
workflow = self.services.workflow
checkWorkflowStopped(self.services)
# Log fast path start
self.services.chat.storeLog(workflow, {
"message": "Fast path execution started",
"type": "info",
"status": "running",
"progress": 0.1
})
# Get user language if available
userLanguage = getattr(self.services, 'currentUserLanguage', None)
# Execute fast path - use normalizedRequest if available, otherwise use raw prompt
normalizedPrompt = getattr(self.services, 'currentUserPromptNormalized', None) or userInput.prompt
# Log fast path execution
self.services.chat.storeLog(workflow, {
"message": f"Processing request via fast path (language: {userLanguage or 'auto'})",
"type": "info",
"status": "running",
"progress": 0.3
})
result = await self.workflowProcessor.fastPathExecute(
prompt=normalizedPrompt,
documents=documents,
@@ -241,6 +258,12 @@
if not result.success:
# Fast path failed, fall back to full workflow
logger.warning(f"Fast path failed: {result.error}, falling back to full workflow")
self.services.chat.storeLog(workflow, {
"message": f"Fast path failed: {result.error}. Falling back to full workflow.",
"type": "warning",
"status": "running",
"progress": 0.5
})
taskPlan = await self._planTasks(userInput)
await self._executeTasks(taskPlan)
await self._processWorkflowResults()
@@ -288,7 +311,58 @@
}
chatDocuments.append(chatDoc)
- # Mark workflow as completed BEFORE storing message (so UI polling stops)
# Create initial user message first
roundNum = workflow.currentRound
contextLabel = f"round{roundNum}_usercontext"
userMessageData = {
"workflowId": workflow.id,
"role": "user",
"message": userInput.prompt,
"status": "first",
"sequenceNr": 1,
"publishedAt": self.services.utils.timestampGetUtc(),
"documentsLabel": contextLabel,
"documents": [],
# Add workflow context fields
"roundNumber": workflow.currentRound,
"taskNumber": 0,
"actionNumber": 0,
# Add progress status
"taskProgress": "pending",
"actionProgress": "pending"
}
# Store user message (with any uploaded documents)
# Convert ChatDocument objects to dicts for storeMessageWithDocuments
userDocuments = []
for doc in documents:
if isinstance(doc, dict):
userDoc = dict(doc)
else:
# ChatDocument object - convert to dict
userDoc = {
"fileId": doc.fileId,
"fileName": doc.fileName,
"fileSize": doc.fileSize,
"mimeType": doc.mimeType,
"roundNumber": workflow.currentRound,
"taskNumber": 0,
"actionNumber": 0
}
userDocuments.append(userDoc)
self.services.chat.storeMessageWithDocuments(workflow, userMessageData, userDocuments)
# Log user message stored
self.services.chat.storeLog(workflow, {
"message": "User message stored",
"type": "info",
"status": "running",
"progress": 0.6
})
+ # Mark workflow as completed BEFORE storing response message (so UI polling stops)
workflow.status = "completed"
workflow.lastActivity = self.services.utils.timestampGetUtc()
self.services.chat.updateWorkflow(workflow.id, {
@@ -296,6 +370,14 @@
"lastActivity": workflow.lastActivity
})
# Log response generation
self.services.chat.storeLog(workflow, {
"message": "Generating fast path response",
"type": "info",
"status": "running",
"progress": 0.8
})
# Create ChatMessage with fast path response (in user's language)
messageData = {
"workflowId": workflow.id,
@@ -318,10 +400,25 @@
# Store message with documents
self.services.chat.storeMessageWithDocuments(workflow, messageData, chatDocuments)
# Log fast path completion
self.services.chat.storeLog(workflow, {
"message": f"Fast path completed successfully (response length: {len(responseText)} chars)",
"type": "info",
"status": "completed",
"progress": 1.0
})
logger.info(f"Fast path completed successfully, response length: {len(responseText)} chars")
except Exception as e:
logger.error(f"Error in _executeFastPath: {str(e)}")
# Log error
self.services.chat.storeLog(workflow, {
"message": f"Fast path error: {str(e)}. Falling back to full workflow.",
"type": "error",
"status": "running",
"progress": 0.5
})
# Fall back to full workflow on error
logger.info("Falling back to full workflow due to fast path error")
taskPlan = await self._planTasks(userInput)
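Collected in one place, the progress values the new storeLog calls report for the fast path (a summary of the calls above; the constant name is ours):
# Progress milestones emitted by the fast path logging above:
FAST_PATH_PROGRESS = {
    0.1: "Fast path execution started",
    0.3: "Processing request via fast path",
    0.5: "Falling back to full workflow (on failure or error)",
    0.6: "User message stored",
    0.8: "Generating fast path response",
    1.0: "Fast path completed successfully",
}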