PEK update
This commit is contained in:
parent
ad14e272f3
commit
12dd5aaea6
8 changed files with 400 additions and 40 deletions
|
|
@ -78,3 +78,28 @@ class PaginatedResponse(BaseModel, Generic[T]):
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_pagination_dict(pagination_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Normalize pagination dictionary to handle frontend variations.
|
||||||
|
Moves top-level "search" field into filters if present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pagination_dict: Raw pagination dictionary from frontend
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized pagination dictionary ready for PaginationParams parsing
|
||||||
|
"""
|
||||||
|
if not pagination_dict:
|
||||||
|
return pagination_dict
|
||||||
|
|
||||||
|
# Create a copy to avoid modifying the original
|
||||||
|
normalized = dict(pagination_dict)
|
||||||
|
|
||||||
|
# Move top-level "search" into filters if present
|
||||||
|
if "search" in normalized:
|
||||||
|
if "filters" not in normalized or normalized["filters"] is None:
|
||||||
|
normalized["filters"] = {}
|
||||||
|
normalized["filters"]["search"] = normalized.pop("search")
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
|
||||||
|
|
@ -384,12 +384,14 @@ class ChatObjects:
|
||||||
matches = True
|
matches = True
|
||||||
|
|
||||||
# Handle general search across text fields
|
# Handle general search across text fields
|
||||||
if "search" in filters:
|
if "search" in filters and filters["search"] is not None:
|
||||||
search_term = str(filters["search"]).lower()
|
search_term = str(filters["search"]).strip().lower()
|
||||||
if search_term:
|
if search_term:
|
||||||
# Search in all string fields
|
# Search in all string fields
|
||||||
found = False
|
found = False
|
||||||
for key, value in record.items():
|
for key, value in record.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
if isinstance(value, str) and search_term in value.lower():
|
if isinstance(value, str) and search_term in value.lower():
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
|
|
@ -404,14 +406,36 @@ class ChatObjects:
|
||||||
if field_name == "search":
|
if field_name == "search":
|
||||||
continue # Already handled above
|
continue # Already handled above
|
||||||
|
|
||||||
|
# Skip None or empty filter values (but empty string for strings should still filter)
|
||||||
|
if filter_value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If field doesn't exist in record, reject this record for this filter
|
||||||
if field_name not in record:
|
if field_name not in record:
|
||||||
matches = False
|
matches = False
|
||||||
break
|
break
|
||||||
|
|
||||||
record_value = record.get(field_name)
|
record_value = record.get(field_name)
|
||||||
|
|
||||||
# Handle simple value (equals operator)
|
# Handle simple value - default to "contains" for strings, "equals" for other types
|
||||||
if not isinstance(filter_value, dict):
|
if not isinstance(filter_value, dict):
|
||||||
|
# Skip None values in record
|
||||||
|
if record_value is None:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
|
||||||
|
# For string fields, default to "contains" for more intuitive filtering
|
||||||
|
if isinstance(record_value, str) and isinstance(filter_value, str):
|
||||||
|
# Skip empty filter strings
|
||||||
|
filter_str = str(filter_value).strip().lower()
|
||||||
|
if not filter_str:
|
||||||
|
continue
|
||||||
|
record_str = record_value.lower()
|
||||||
|
if filter_str not in record_str:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# For non-string fields, use exact match
|
||||||
if record_value != filter_value:
|
if record_value != filter_value:
|
||||||
matches = False
|
matches = False
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -308,12 +308,14 @@ class ComponentObjects:
|
||||||
matches = True
|
matches = True
|
||||||
|
|
||||||
# Handle general search across text fields
|
# Handle general search across text fields
|
||||||
if "search" in filters:
|
if "search" in filters and filters["search"] is not None:
|
||||||
search_term = str(filters["search"]).lower()
|
search_term = str(filters["search"]).strip().lower()
|
||||||
if search_term:
|
if search_term:
|
||||||
# Search in all string fields
|
# Search in all string fields
|
||||||
found = False
|
found = False
|
||||||
for key, value in record.items():
|
for key, value in record.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
if isinstance(value, str) and search_term in value.lower():
|
if isinstance(value, str) and search_term in value.lower():
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
|
|
@ -322,20 +324,69 @@ class ComponentObjects:
|
||||||
break
|
break
|
||||||
if not found:
|
if not found:
|
||||||
matches = False
|
matches = False
|
||||||
|
# Continue checking other filters, but this record won't match
|
||||||
|
|
||||||
# Handle field-specific filters
|
# Handle field-specific filters
|
||||||
for field_name, filter_value in filters.items():
|
for field_name, filter_value in filters.items():
|
||||||
if field_name == "search":
|
if field_name == "search":
|
||||||
continue # Already handled above
|
continue # Already handled above
|
||||||
|
|
||||||
|
# Skip None or empty filter values (but empty string for strings should still filter)
|
||||||
|
if filter_value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If field doesn't exist in record, reject this record for this filter
|
||||||
if field_name not in record:
|
if field_name not in record:
|
||||||
matches = False
|
matches = False
|
||||||
break
|
break
|
||||||
|
|
||||||
record_value = record.get(field_name)
|
record_value = record.get(field_name)
|
||||||
|
|
||||||
# Handle simple value (equals operator)
|
# Handle simple value - default to "contains" for strings, "equals" for other types
|
||||||
if not isinstance(filter_value, dict):
|
if not isinstance(filter_value, dict):
|
||||||
|
# Skip None values in record
|
||||||
|
if record_value is None:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
|
||||||
|
# Special handling for fileSize field - parse formatted size strings
|
||||||
|
if field_name == "fileSize" and isinstance(filter_value, str):
|
||||||
|
try:
|
||||||
|
# Parse formatted size string (e.g., "2.13 MB", "1.5 GB", "500 KB")
|
||||||
|
filter_size_bytes = self._parse_size_string(filter_value)
|
||||||
|
if filter_size_bytes is not None:
|
||||||
|
# Compare as integers (bytes)
|
||||||
|
if isinstance(record_value, (int, float)):
|
||||||
|
# Allow small tolerance for rounding differences (5% or 50KB, whichever is smaller)
|
||||||
|
# This accounts for formatting differences but prevents matching very different sizes
|
||||||
|
tolerance = min(filter_size_bytes * 0.05, 50 * 1024)
|
||||||
|
if abs(record_value - filter_size_bytes) > tolerance:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# If parsing fails, try string contains match
|
||||||
|
if str(record_value) not in filter_value:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
# If parsing fails, skip this filter
|
||||||
|
continue
|
||||||
|
|
||||||
|
# For string fields, default to "contains" for more intuitive filtering
|
||||||
|
elif isinstance(record_value, str) and isinstance(filter_value, str):
|
||||||
|
# Skip empty filter strings
|
||||||
|
filter_str = str(filter_value).strip().lower()
|
||||||
|
if not filter_str:
|
||||||
|
continue
|
||||||
|
record_str = record_value.lower()
|
||||||
|
if filter_str not in record_str:
|
||||||
|
matches = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# For non-string fields, use exact match
|
||||||
if record_value != filter_value:
|
if record_value != filter_value:
|
||||||
matches = False
|
matches = False
|
||||||
break
|
break
|
||||||
|
|
@ -492,6 +543,49 @@ class ComponentObjects:
|
||||||
"""Returns the initial ID for a table."""
|
"""Returns the initial ID for a table."""
|
||||||
return self.db.getInitialId(model_class)
|
return self.db.getInitialId(model_class)
|
||||||
|
|
||||||
|
def _parse_size_string(self, size_str: str) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Parse a formatted size string (e.g., "2.13 MB", "1.5 GB") to bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size_str: Formatted size string like "2.13 MB", "1.5 GB", "500 KB"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Size in bytes, or None if parsing fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
size_str = size_str.strip().upper()
|
||||||
|
# Remove common separators and spaces
|
||||||
|
size_str = size_str.replace(",", "").replace(" ", "")
|
||||||
|
|
||||||
|
# Extract number and unit - handle both "MB" and "M" formats
|
||||||
|
import re
|
||||||
|
# Match: number (with optional decimal) followed by optional unit (K/M/G/T with optional B)
|
||||||
|
match = re.match(r"^([\d.]+)([KMGT]?B?)$", size_str)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
|
||||||
|
number = float(match.group(1))
|
||||||
|
unit = match.group(2) or "B"
|
||||||
|
|
||||||
|
# Normalize unit (handle "M" as "MB", "K" as "KB", etc.)
|
||||||
|
if len(unit) == 1 and unit in "KMGT":
|
||||||
|
unit = unit + "B"
|
||||||
|
|
||||||
|
# Convert to bytes
|
||||||
|
multipliers = {
|
||||||
|
"B": 1,
|
||||||
|
"KB": 1024,
|
||||||
|
"MB": 1024 * 1024,
|
||||||
|
"GB": 1024 * 1024 * 1024,
|
||||||
|
"TB": 1024 * 1024 * 1024 * 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
multiplier = multipliers.get(unit, 1)
|
||||||
|
return int(number * multiplier)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Prompt methods
|
# Prompt methods
|
||||||
|
|
|
||||||
|
|
@ -8,17 +8,20 @@ SECURITY NOTE:
|
||||||
- This prevents security vulnerabilities where admin users could see other users' connections
|
- This prevents security vulnerabilities where admin users could see other users' connections
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
|
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
|
|
||||||
from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority, ConnectionStatus
|
from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority, ConnectionStatus
|
||||||
from modules.datamodels.datamodelSecurity import Token
|
from modules.datamodels.datamodelSecurity import Token
|
||||||
from modules.auth import getCurrentUser, limiter
|
from modules.auth import getCurrentUser, limiter
|
||||||
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||||
from modules.interfaces.interfaceDbAppObjects import getInterface
|
from modules.interfaces.interfaceDbAppObjects import getInterface
|
||||||
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
|
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
|
||||||
|
from modules.interfaces.interfaceDbComponentObjects import ComponentObjects
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -87,20 +90,44 @@ router = APIRouter(
|
||||||
responses={404: {"description": "Not found"}}
|
responses={404: {"description": "Not found"}}
|
||||||
)
|
)
|
||||||
|
|
||||||
@router.get("/", response_model=List[UserConnection])
|
@router.get("/", response_model=PaginatedResponse[UserConnection])
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
async def get_connections(
|
async def get_connections(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||||
currentUser: User = Depends(getCurrentUser)
|
currentUser: User = Depends(getCurrentUser)
|
||||||
) -> List[UserConnection]:
|
) -> PaginatedResponse[UserConnection]:
|
||||||
"""Get all connections for the current user
|
"""Get connections for the current user with optional pagination, sorting, and filtering.
|
||||||
|
|
||||||
SECURITY: This endpoint is secure - users can only see their own connections.
|
SECURITY: This endpoint is secure - users can only see their own connections.
|
||||||
Automatically refreshes expired OAuth tokens in the background.
|
Automatically refreshes expired OAuth tokens in the background.
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- GET /api/connections/ (no pagination - returns all items)
|
||||||
|
- GET /api/connections/?pagination={"page":1,"pageSize":10,"sort":[]}
|
||||||
|
- GET /api/connections/?pagination={"page":1,"pageSize":10,"filters":{"status":"active"}}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
interface = getInterface(currentUser)
|
interface = getInterface(currentUser)
|
||||||
|
|
||||||
|
# Parse pagination parameter
|
||||||
|
paginationParams = None
|
||||||
|
if pagination:
|
||||||
|
try:
|
||||||
|
paginationDict = json.loads(pagination)
|
||||||
|
if paginationDict:
|
||||||
|
# Normalize pagination dict (handles top-level "search" field)
|
||||||
|
paginationDict = normalize_pagination_dict(paginationDict)
|
||||||
|
paginationParams = PaginationParams(**paginationDict)
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid pagination parameter: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
# SECURITY FIX: All users (including admins) can only see their own connections
|
# SECURITY FIX: All users (including admins) can only see their own connections
|
||||||
# This prevents admin from seeing other users' connections and causing confusion
|
# This prevents admin from seeing other users' connections and causing confusion
|
||||||
connections = interface.getUserConnections(currentUser.id)
|
connections = interface.getUserConnections(currentUser.id)
|
||||||
|
|
@ -117,33 +144,111 @@ async def get_connections(
|
||||||
logger.warning(f"Silent token refresh failed for user {currentUser.id}: {str(e)}")
|
logger.warning(f"Silent token refresh failed for user {currentUser.id}: {str(e)}")
|
||||||
# Continue with original connections even if refresh fails
|
# Continue with original connections even if refresh fails
|
||||||
|
|
||||||
# Enhance each connection with token status information
|
# Enhance each connection with token status information and convert to dict
|
||||||
enhanced_connections = []
|
enhanced_connections_dict = []
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
# Get token status for this connection
|
# Get token status for this connection
|
||||||
tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
|
tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
|
||||||
|
|
||||||
# Create enhanced connection with token status
|
# Convert to dict for filtering/sorting
|
||||||
enhanced_connection = UserConnection(
|
connection_dict = {
|
||||||
id=connection.id,
|
"id": connection.id,
|
||||||
userId=connection.userId,
|
"userId": connection.userId,
|
||||||
authority=connection.authority,
|
"authority": connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority),
|
||||||
externalId=connection.externalId,
|
"externalId": connection.externalId,
|
||||||
externalUsername=connection.externalUsername,
|
"externalUsername": connection.externalUsername or "",
|
||||||
externalEmail=connection.externalEmail,
|
"externalEmail": connection.externalEmail, # Keep None instead of converting to empty string
|
||||||
status=connection.status,
|
"status": connection.status.value if hasattr(connection.status, 'value') else str(connection.status),
|
||||||
connectedAt=connection.connectedAt,
|
"connectedAt": connection.connectedAt,
|
||||||
lastChecked=connection.lastChecked,
|
"lastChecked": connection.lastChecked,
|
||||||
expiresAt=connection.expiresAt,
|
"expiresAt": connection.expiresAt,
|
||||||
tokenStatus=tokenStatus,
|
"tokenStatus": tokenStatus,
|
||||||
tokenExpiresAt=tokenExpiresAt
|
"tokenExpiresAt": tokenExpiresAt
|
||||||
|
}
|
||||||
|
enhanced_connections_dict.append(connection_dict)
|
||||||
|
|
||||||
|
# If no pagination requested, return all items
|
||||||
|
if paginationParams is None:
|
||||||
|
# Convert back to UserConnection objects (enum strings are already in dict)
|
||||||
|
items = []
|
||||||
|
for conn_dict in enhanced_connections_dict:
|
||||||
|
conn_dict_copy = dict(conn_dict)
|
||||||
|
if "authority" in conn_dict_copy and isinstance(conn_dict_copy["authority"], str):
|
||||||
|
try:
|
||||||
|
conn_dict_copy["authority"] = AuthAuthority(conn_dict_copy["authority"])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
if "status" in conn_dict_copy and isinstance(conn_dict_copy["status"], str):
|
||||||
|
try:
|
||||||
|
conn_dict_copy["status"] = ConnectionStatus(conn_dict_copy["status"])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
items.append(UserConnection(**conn_dict_copy))
|
||||||
|
return PaginatedResponse(
|
||||||
|
items=items,
|
||||||
|
pagination=None
|
||||||
)
|
)
|
||||||
enhanced_connections.append(enhanced_connection)
|
|
||||||
|
|
||||||
return enhanced_connections
|
# Apply filtering if provided
|
||||||
|
if paginationParams.filters:
|
||||||
|
component_interface = ComponentObjects()
|
||||||
|
component_interface.setUserContext(currentUser)
|
||||||
|
enhanced_connections_dict = component_interface._applyFilters(
|
||||||
|
enhanced_connections_dict,
|
||||||
|
paginationParams.filters
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply sorting if provided
|
||||||
|
if paginationParams.sort:
|
||||||
|
component_interface = ComponentObjects()
|
||||||
|
component_interface.setUserContext(currentUser)
|
||||||
|
enhanced_connections_dict = component_interface._applySorting(
|
||||||
|
enhanced_connections_dict,
|
||||||
|
paginationParams.sort
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count total items after filters
|
||||||
|
totalItems = len(enhanced_connections_dict)
|
||||||
|
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||||
|
|
||||||
|
# Apply pagination (skip/limit)
|
||||||
|
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||||
|
endIdx = startIdx + paginationParams.pageSize
|
||||||
|
paged_connections = enhanced_connections_dict[startIdx:endIdx]
|
||||||
|
|
||||||
|
# Convert back to UserConnection objects (convert enum strings back to enums)
|
||||||
|
items = []
|
||||||
|
for conn_dict in paged_connections:
|
||||||
|
# Convert enum strings back to enum objects
|
||||||
|
conn_dict_copy = dict(conn_dict)
|
||||||
|
if "authority" in conn_dict_copy and isinstance(conn_dict_copy["authority"], str):
|
||||||
|
try:
|
||||||
|
conn_dict_copy["authority"] = AuthAuthority(conn_dict_copy["authority"])
|
||||||
|
except ValueError:
|
||||||
|
pass # Keep as string if invalid
|
||||||
|
if "status" in conn_dict_copy and isinstance(conn_dict_copy["status"], str):
|
||||||
|
try:
|
||||||
|
conn_dict_copy["status"] = ConnectionStatus(conn_dict_copy["status"])
|
||||||
|
except ValueError:
|
||||||
|
pass # Keep as string if invalid
|
||||||
|
items.append(UserConnection(**conn_dict_copy))
|
||||||
|
|
||||||
|
return PaginatedResponse(
|
||||||
|
items=items,
|
||||||
|
pagination=PaginationMetadata(
|
||||||
|
currentPage=paginationParams.page,
|
||||||
|
pageSize=paginationParams.pageSize,
|
||||||
|
totalItems=totalItems,
|
||||||
|
totalPages=totalPages,
|
||||||
|
sort=paginationParams.sort,
|
||||||
|
filters=paginationParams.filters
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting connections: {str(e)}")
|
logger.error(f"Error getting connections: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to get connections: {str(e)}"
|
detail=f"Failed to get connections: {str(e)}"
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObj
|
||||||
from modules.datamodels.datamodelFiles import FileItem, FilePreview
|
from modules.datamodels.datamodelFiles import FileItem, FilePreview
|
||||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||||
from modules.datamodels.datamodelUam import User
|
from modules.datamodels.datamodelUam import User
|
||||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -57,7 +57,10 @@ async def get_files(
|
||||||
if pagination:
|
if pagination:
|
||||||
try:
|
try:
|
||||||
paginationDict = json.loads(pagination)
|
paginationDict = json.loads(pagination)
|
||||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
if paginationDict:
|
||||||
|
# Normalize pagination dict (handles top-level "search" field)
|
||||||
|
paginationDict = normalize_pagination_dict(paginationDict)
|
||||||
|
paginationParams = PaginationParams(**paginationDict)
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from modules.auth import limiter, getCurrentUser
|
||||||
import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObjects
|
import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObjects
|
||||||
from modules.datamodels.datamodelUtils import Prompt
|
from modules.datamodels.datamodelUtils import Prompt
|
||||||
from modules.datamodels.datamodelUam import User
|
from modules.datamodels.datamodelUam import User
|
||||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -46,7 +46,10 @@ async def get_prompts(
|
||||||
if pagination:
|
if pagination:
|
||||||
try:
|
try:
|
||||||
paginationDict = json.loads(pagination)
|
paginationDict = json.loads(pagination)
|
||||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
if paginationDict:
|
||||||
|
# Normalize pagination dict (handles top-level "search" field)
|
||||||
|
paginationDict = normalize_pagination_dict(paginationDict)
|
||||||
|
paginationParams = PaginationParams(**paginationDict)
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from modules.datamodels.datamodelChat import (
|
||||||
)
|
)
|
||||||
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse
|
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse
|
||||||
from modules.datamodels.datamodelUam import User
|
from modules.datamodels.datamodelUam import User
|
||||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||||
|
|
||||||
|
|
||||||
# Configure logger
|
# Configure logger
|
||||||
|
|
@ -69,7 +69,10 @@ async def get_workflows(
|
||||||
if pagination:
|
if pagination:
|
||||||
try:
|
try:
|
||||||
paginationDict = json.loads(pagination)
|
paginationDict = json.loads(pagination)
|
||||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
if paginationDict:
|
||||||
|
# Normalize pagination dict (handles top-level "search" field)
|
||||||
|
paginationDict = normalize_pagination_dict(paginationDict)
|
||||||
|
paginationParams = PaginationParams(**paginationDict)
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
|
|
@ -262,7 +265,10 @@ async def get_workflow_logs(
|
||||||
if pagination:
|
if pagination:
|
||||||
try:
|
try:
|
||||||
paginationDict = json.loads(pagination)
|
paginationDict = json.loads(pagination)
|
||||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
if paginationDict:
|
||||||
|
# Normalize pagination dict (handles top-level "search" field)
|
||||||
|
paginationDict = normalize_pagination_dict(paginationDict)
|
||||||
|
paginationParams = PaginationParams(**paginationDict)
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
|
|
@ -350,7 +356,10 @@ async def get_workflow_messages(
|
||||||
if pagination:
|
if pagination:
|
||||||
try:
|
try:
|
||||||
paginationDict = json.loads(pagination)
|
paginationDict = json.loads(pagination)
|
||||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
if paginationDict:
|
||||||
|
# Normalize pagination dict (handles top-level "search" field)
|
||||||
|
paginationDict = normalize_pagination_dict(paginationDict)
|
||||||
|
paginationParams = PaginationParams(**paginationDict)
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
|
|
|
||||||
|
|
@ -227,11 +227,28 @@ class WorkflowManager:
|
||||||
workflow = self.services.workflow
|
workflow = self.services.workflow
|
||||||
checkWorkflowStopped(self.services)
|
checkWorkflowStopped(self.services)
|
||||||
|
|
||||||
|
# Log fast path start
|
||||||
|
self.services.chat.storeLog(workflow, {
|
||||||
|
"message": "Fast path execution started",
|
||||||
|
"type": "info",
|
||||||
|
"status": "running",
|
||||||
|
"progress": 0.1
|
||||||
|
})
|
||||||
|
|
||||||
# Get user language if available
|
# Get user language if available
|
||||||
userLanguage = getattr(self.services, 'currentUserLanguage', None)
|
userLanguage = getattr(self.services, 'currentUserLanguage', None)
|
||||||
|
|
||||||
# Execute fast path - use normalizedRequest if available, otherwise use raw prompt
|
# Execute fast path - use normalizedRequest if available, otherwise use raw prompt
|
||||||
normalizedPrompt = getattr(self.services, 'currentUserPromptNormalized', None) or userInput.prompt
|
normalizedPrompt = getattr(self.services, 'currentUserPromptNormalized', None) or userInput.prompt
|
||||||
|
|
||||||
|
# Log fast path execution
|
||||||
|
self.services.chat.storeLog(workflow, {
|
||||||
|
"message": f"Processing request via fast path (language: {userLanguage or 'auto'})",
|
||||||
|
"type": "info",
|
||||||
|
"status": "running",
|
||||||
|
"progress": 0.3
|
||||||
|
})
|
||||||
|
|
||||||
result = await self.workflowProcessor.fastPathExecute(
|
result = await self.workflowProcessor.fastPathExecute(
|
||||||
prompt=normalizedPrompt,
|
prompt=normalizedPrompt,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
|
|
@ -241,6 +258,12 @@ class WorkflowManager:
|
||||||
if not result.success:
|
if not result.success:
|
||||||
# Fast path failed, fall back to full workflow
|
# Fast path failed, fall back to full workflow
|
||||||
logger.warning(f"Fast path failed: {result.error}, falling back to full workflow")
|
logger.warning(f"Fast path failed: {result.error}, falling back to full workflow")
|
||||||
|
self.services.chat.storeLog(workflow, {
|
||||||
|
"message": f"Fast path failed: {result.error}. Falling back to full workflow.",
|
||||||
|
"type": "warning",
|
||||||
|
"status": "running",
|
||||||
|
"progress": 0.5
|
||||||
|
})
|
||||||
taskPlan = await self._planTasks(userInput)
|
taskPlan = await self._planTasks(userInput)
|
||||||
await self._executeTasks(taskPlan)
|
await self._executeTasks(taskPlan)
|
||||||
await self._processWorkflowResults()
|
await self._processWorkflowResults()
|
||||||
|
|
@ -288,7 +311,58 @@ class WorkflowManager:
|
||||||
}
|
}
|
||||||
chatDocuments.append(chatDoc)
|
chatDocuments.append(chatDoc)
|
||||||
|
|
||||||
# Mark workflow as completed BEFORE storing message (so UI polling stops)
|
# Create initial user message first
|
||||||
|
roundNum = workflow.currentRound
|
||||||
|
contextLabel = f"round{roundNum}_usercontext"
|
||||||
|
|
||||||
|
userMessageData = {
|
||||||
|
"workflowId": workflow.id,
|
||||||
|
"role": "user",
|
||||||
|
"message": userInput.prompt,
|
||||||
|
"status": "first",
|
||||||
|
"sequenceNr": 1,
|
||||||
|
"publishedAt": self.services.utils.timestampGetUtc(),
|
||||||
|
"documentsLabel": contextLabel,
|
||||||
|
"documents": [],
|
||||||
|
# Add workflow context fields
|
||||||
|
"roundNumber": workflow.currentRound,
|
||||||
|
"taskNumber": 0,
|
||||||
|
"actionNumber": 0,
|
||||||
|
# Add progress status
|
||||||
|
"taskProgress": "pending",
|
||||||
|
"actionProgress": "pending"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store user message (with any uploaded documents)
|
||||||
|
# Convert ChatDocument objects to dicts for storeMessageWithDocuments
|
||||||
|
userDocuments = []
|
||||||
|
for doc in documents:
|
||||||
|
if isinstance(doc, dict):
|
||||||
|
userDoc = dict(doc)
|
||||||
|
else:
|
||||||
|
# ChatDocument object - convert to dict
|
||||||
|
userDoc = {
|
||||||
|
"fileId": doc.fileId,
|
||||||
|
"fileName": doc.fileName,
|
||||||
|
"fileSize": doc.fileSize,
|
||||||
|
"mimeType": doc.mimeType,
|
||||||
|
"roundNumber": workflow.currentRound,
|
||||||
|
"taskNumber": 0,
|
||||||
|
"actionNumber": 0
|
||||||
|
}
|
||||||
|
userDocuments.append(userDoc)
|
||||||
|
|
||||||
|
self.services.chat.storeMessageWithDocuments(workflow, userMessageData, userDocuments)
|
||||||
|
|
||||||
|
# Log user message stored
|
||||||
|
self.services.chat.storeLog(workflow, {
|
||||||
|
"message": "User message stored",
|
||||||
|
"type": "info",
|
||||||
|
"status": "running",
|
||||||
|
"progress": 0.6
|
||||||
|
})
|
||||||
|
|
||||||
|
# Mark workflow as completed BEFORE storing response message (so UI polling stops)
|
||||||
workflow.status = "completed"
|
workflow.status = "completed"
|
||||||
workflow.lastActivity = self.services.utils.timestampGetUtc()
|
workflow.lastActivity = self.services.utils.timestampGetUtc()
|
||||||
self.services.chat.updateWorkflow(workflow.id, {
|
self.services.chat.updateWorkflow(workflow.id, {
|
||||||
|
|
@ -296,6 +370,14 @@ class WorkflowManager:
|
||||||
"lastActivity": workflow.lastActivity
|
"lastActivity": workflow.lastActivity
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Log response generation
|
||||||
|
self.services.chat.storeLog(workflow, {
|
||||||
|
"message": "Generating fast path response",
|
||||||
|
"type": "info",
|
||||||
|
"status": "running",
|
||||||
|
"progress": 0.8
|
||||||
|
})
|
||||||
|
|
||||||
# Create ChatMessage with fast path response (in user's language)
|
# Create ChatMessage with fast path response (in user's language)
|
||||||
messageData = {
|
messageData = {
|
||||||
"workflowId": workflow.id,
|
"workflowId": workflow.id,
|
||||||
|
|
@ -318,10 +400,25 @@ class WorkflowManager:
|
||||||
# Store message with documents
|
# Store message with documents
|
||||||
self.services.chat.storeMessageWithDocuments(workflow, messageData, chatDocuments)
|
self.services.chat.storeMessageWithDocuments(workflow, messageData, chatDocuments)
|
||||||
|
|
||||||
|
# Log fast path completion
|
||||||
|
self.services.chat.storeLog(workflow, {
|
||||||
|
"message": f"Fast path completed successfully (response length: {len(responseText)} chars)",
|
||||||
|
"type": "info",
|
||||||
|
"status": "completed",
|
||||||
|
"progress": 1.0
|
||||||
|
})
|
||||||
|
|
||||||
logger.info(f"Fast path completed successfully, response length: {len(responseText)} chars")
|
logger.info(f"Fast path completed successfully, response length: {len(responseText)} chars")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in _executeFastPath: {str(e)}")
|
logger.error(f"Error in _executeFastPath: {str(e)}")
|
||||||
|
# Log error
|
||||||
|
self.services.chat.storeLog(workflow, {
|
||||||
|
"message": f"Fast path error: {str(e)}. Falling back to full workflow.",
|
||||||
|
"type": "error",
|
||||||
|
"status": "running",
|
||||||
|
"progress": 0.5
|
||||||
|
})
|
||||||
# Fall back to full workflow on error
|
# Fall back to full workflow on error
|
||||||
logger.info("Falling back to full workflow due to fast path error")
|
logger.info("Falling back to full workflow due to fast path error")
|
||||||
taskPlan = await self._planTasks(userInput)
|
taskPlan = await self._planTasks(userInput)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue