PEK update
This commit is contained in:
parent
ad14e272f3
commit
12dd5aaea6
8 changed files with 400 additions and 40 deletions
|
|
@ -78,3 +78,28 @@ class PaginatedResponse(BaseModel, Generic[T]):
|
|||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
def normalize_pagination_dict(pagination_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Normalize pagination dictionary to handle frontend variations.
|
||||
Moves top-level "search" field into filters if present.
|
||||
|
||||
Args:
|
||||
pagination_dict: Raw pagination dictionary from frontend
|
||||
|
||||
Returns:
|
||||
Normalized pagination dictionary ready for PaginationParams parsing
|
||||
"""
|
||||
if not pagination_dict:
|
||||
return pagination_dict
|
||||
|
||||
# Create a copy to avoid modifying the original
|
||||
normalized = dict(pagination_dict)
|
||||
|
||||
# Move top-level "search" into filters if present
|
||||
if "search" in normalized:
|
||||
if "filters" not in normalized or normalized["filters"] is None:
|
||||
normalized["filters"] = {}
|
||||
normalized["filters"]["search"] = normalized.pop("search")
|
||||
|
||||
return normalized
|
||||
|
|
|
|||
|
|
@ -384,12 +384,14 @@ class ChatObjects:
|
|||
matches = True
|
||||
|
||||
# Handle general search across text fields
|
||||
if "search" in filters:
|
||||
search_term = str(filters["search"]).lower()
|
||||
if "search" in filters and filters["search"] is not None:
|
||||
search_term = str(filters["search"]).strip().lower()
|
||||
if search_term:
|
||||
# Search in all string fields
|
||||
found = False
|
||||
for key, value in record.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str) and search_term in value.lower():
|
||||
found = True
|
||||
break
|
||||
|
|
@ -404,17 +406,39 @@ class ChatObjects:
|
|||
if field_name == "search":
|
||||
continue # Already handled above
|
||||
|
||||
# Skip None or empty filter values (but empty string for strings should still filter)
|
||||
if filter_value is None:
|
||||
continue
|
||||
|
||||
# If field doesn't exist in record, reject this record for this filter
|
||||
if field_name not in record:
|
||||
matches = False
|
||||
break
|
||||
|
||||
record_value = record.get(field_name)
|
||||
|
||||
# Handle simple value (equals operator)
|
||||
# Handle simple value - default to "contains" for strings, "equals" for other types
|
||||
if not isinstance(filter_value, dict):
|
||||
if record_value != filter_value:
|
||||
# Skip None values in record
|
||||
if record_value is None:
|
||||
matches = False
|
||||
break
|
||||
|
||||
# For string fields, default to "contains" for more intuitive filtering
|
||||
if isinstance(record_value, str) and isinstance(filter_value, str):
|
||||
# Skip empty filter strings
|
||||
filter_str = str(filter_value).strip().lower()
|
||||
if not filter_str:
|
||||
continue
|
||||
record_str = record_value.lower()
|
||||
if filter_str not in record_str:
|
||||
matches = False
|
||||
break
|
||||
else:
|
||||
# For non-string fields, use exact match
|
||||
if record_value != filter_value:
|
||||
matches = False
|
||||
break
|
||||
continue
|
||||
|
||||
# Handle filter with operator
|
||||
|
|
|
|||
|
|
@ -308,12 +308,14 @@ class ComponentObjects:
|
|||
matches = True
|
||||
|
||||
# Handle general search across text fields
|
||||
if "search" in filters:
|
||||
search_term = str(filters["search"]).lower()
|
||||
if "search" in filters and filters["search"] is not None:
|
||||
search_term = str(filters["search"]).strip().lower()
|
||||
if search_term:
|
||||
# Search in all string fields
|
||||
found = False
|
||||
for key, value in record.items():
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str) and search_term in value.lower():
|
||||
found = True
|
||||
break
|
||||
|
|
@ -322,23 +324,72 @@ class ComponentObjects:
|
|||
break
|
||||
if not found:
|
||||
matches = False
|
||||
# Continue checking other filters, but this record won't match
|
||||
|
||||
# Handle field-specific filters
|
||||
for field_name, filter_value in filters.items():
|
||||
if field_name == "search":
|
||||
continue # Already handled above
|
||||
|
||||
# Skip None or empty filter values (but empty string for strings should still filter)
|
||||
if filter_value is None:
|
||||
continue
|
||||
|
||||
# If field doesn't exist in record, reject this record for this filter
|
||||
if field_name not in record:
|
||||
matches = False
|
||||
break
|
||||
|
||||
record_value = record.get(field_name)
|
||||
|
||||
# Handle simple value (equals operator)
|
||||
# Handle simple value - default to "contains" for strings, "equals" for other types
|
||||
if not isinstance(filter_value, dict):
|
||||
if record_value != filter_value:
|
||||
# Skip None values in record
|
||||
if record_value is None:
|
||||
matches = False
|
||||
break
|
||||
|
||||
# Special handling for fileSize field - parse formatted size strings
|
||||
if field_name == "fileSize" and isinstance(filter_value, str):
|
||||
try:
|
||||
# Parse formatted size string (e.g., "2.13 MB", "1.5 GB", "500 KB")
|
||||
filter_size_bytes = self._parse_size_string(filter_value)
|
||||
if filter_size_bytes is not None:
|
||||
# Compare as integers (bytes)
|
||||
if isinstance(record_value, (int, float)):
|
||||
# Allow small tolerance for rounding differences (5% or 50KB, whichever is smaller)
|
||||
# This accounts for formatting differences but prevents matching very different sizes
|
||||
tolerance = min(filter_size_bytes * 0.05, 50 * 1024)
|
||||
if abs(record_value - filter_size_bytes) > tolerance:
|
||||
matches = False
|
||||
break
|
||||
else:
|
||||
matches = False
|
||||
break
|
||||
else:
|
||||
# If parsing fails, try string contains match
|
||||
if str(record_value) not in filter_value:
|
||||
matches = False
|
||||
break
|
||||
except Exception:
|
||||
# If parsing fails, skip this filter
|
||||
continue
|
||||
|
||||
# For string fields, default to "contains" for more intuitive filtering
|
||||
elif isinstance(record_value, str) and isinstance(filter_value, str):
|
||||
# Skip empty filter strings
|
||||
filter_str = str(filter_value).strip().lower()
|
||||
if not filter_str:
|
||||
continue
|
||||
record_str = record_value.lower()
|
||||
if filter_str not in record_str:
|
||||
matches = False
|
||||
break
|
||||
else:
|
||||
# For non-string fields, use exact match
|
||||
if record_value != filter_value:
|
||||
matches = False
|
||||
break
|
||||
continue
|
||||
|
||||
# Handle filter with operator
|
||||
|
|
@ -491,6 +542,49 @@ class ComponentObjects:
|
|||
def getInitialId(self, model_class: type) -> Optional[str]:
|
||||
"""Returns the initial ID for a table."""
|
||||
return self.db.getInitialId(model_class)
|
||||
|
||||
def _parse_size_string(self, size_str: str) -> Optional[int]:
|
||||
"""
|
||||
Parse a formatted size string (e.g., "2.13 MB", "1.5 GB") to bytes.
|
||||
|
||||
Args:
|
||||
size_str: Formatted size string like "2.13 MB", "1.5 GB", "500 KB"
|
||||
|
||||
Returns:
|
||||
Size in bytes, or None if parsing fails
|
||||
"""
|
||||
try:
|
||||
size_str = size_str.strip().upper()
|
||||
# Remove common separators and spaces
|
||||
size_str = size_str.replace(",", "").replace(" ", "")
|
||||
|
||||
# Extract number and unit - handle both "MB" and "M" formats
|
||||
import re
|
||||
# Match: number (with optional decimal) followed by optional unit (K/M/G/T with optional B)
|
||||
match = re.match(r"^([\d.]+)([KMGT]?B?)$", size_str)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
number = float(match.group(1))
|
||||
unit = match.group(2) or "B"
|
||||
|
||||
# Normalize unit (handle "M" as "MB", "K" as "KB", etc.)
|
||||
if len(unit) == 1 and unit in "KMGT":
|
||||
unit = unit + "B"
|
||||
|
||||
# Convert to bytes
|
||||
multipliers = {
|
||||
"B": 1,
|
||||
"KB": 1024,
|
||||
"MB": 1024 * 1024,
|
||||
"GB": 1024 * 1024 * 1024,
|
||||
"TB": 1024 * 1024 * 1024 * 1024,
|
||||
}
|
||||
|
||||
multiplier = multipliers.get(unit, 1)
|
||||
return int(number * multiplier)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,17 +8,20 @@ SECURITY NOTE:
|
|||
- This prevents security vulnerabilities where admin users could see other users' connections
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
|
||||
from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority, ConnectionStatus
|
||||
from modules.datamodels.datamodelSecurity import Token
|
||||
from modules.auth import getCurrentUser, limiter
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.interfaces.interfaceDbAppObjects import getInterface
|
||||
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
|
||||
from modules.interfaces.interfaceDbComponentObjects import ComponentObjects
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -87,20 +90,44 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[UserConnection])
|
||||
@router.get("/", response_model=PaginatedResponse[UserConnection])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_connections(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[UserConnection]:
|
||||
"""Get all connections for the current user
|
||||
) -> PaginatedResponse[UserConnection]:
|
||||
"""Get connections for the current user with optional pagination, sorting, and filtering.
|
||||
|
||||
SECURITY: This endpoint is secure - users can only see their own connections.
|
||||
Automatically refreshes expired OAuth tokens in the background.
|
||||
|
||||
Query Parameters:
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
|
||||
Examples:
|
||||
- GET /api/connections/ (no pagination - returns all items)
|
||||
- GET /api/connections/?pagination={"page":1,"pageSize":10,"sort":[]}
|
||||
- GET /api/connections/?pagination={"page":1,"pageSize":10,"filters":{"status":"active"}}
|
||||
"""
|
||||
try:
|
||||
interface = getInterface(currentUser)
|
||||
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
# Normalize pagination dict (handles top-level "search" field)
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
# SECURITY FIX: All users (including admins) can only see their own connections
|
||||
# This prevents admin from seeing other users' connections and causing confusion
|
||||
connections = interface.getUserConnections(currentUser.id)
|
||||
|
|
@ -117,33 +144,111 @@ async def get_connections(
|
|||
logger.warning(f"Silent token refresh failed for user {currentUser.id}: {str(e)}")
|
||||
# Continue with original connections even if refresh fails
|
||||
|
||||
# Enhance each connection with token status information
|
||||
enhanced_connections = []
|
||||
# Enhance each connection with token status information and convert to dict
|
||||
enhanced_connections_dict = []
|
||||
for connection in connections:
|
||||
# Get token status for this connection
|
||||
tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
|
||||
|
||||
# Create enhanced connection with token status
|
||||
enhanced_connection = UserConnection(
|
||||
id=connection.id,
|
||||
userId=connection.userId,
|
||||
authority=connection.authority,
|
||||
externalId=connection.externalId,
|
||||
externalUsername=connection.externalUsername,
|
||||
externalEmail=connection.externalEmail,
|
||||
status=connection.status,
|
||||
connectedAt=connection.connectedAt,
|
||||
lastChecked=connection.lastChecked,
|
||||
expiresAt=connection.expiresAt,
|
||||
tokenStatus=tokenStatus,
|
||||
tokenExpiresAt=tokenExpiresAt
|
||||
# Convert to dict for filtering/sorting
|
||||
connection_dict = {
|
||||
"id": connection.id,
|
||||
"userId": connection.userId,
|
||||
"authority": connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority),
|
||||
"externalId": connection.externalId,
|
||||
"externalUsername": connection.externalUsername or "",
|
||||
"externalEmail": connection.externalEmail, # Keep None instead of converting to empty string
|
||||
"status": connection.status.value if hasattr(connection.status, 'value') else str(connection.status),
|
||||
"connectedAt": connection.connectedAt,
|
||||
"lastChecked": connection.lastChecked,
|
||||
"expiresAt": connection.expiresAt,
|
||||
"tokenStatus": tokenStatus,
|
||||
"tokenExpiresAt": tokenExpiresAt
|
||||
}
|
||||
enhanced_connections_dict.append(connection_dict)
|
||||
|
||||
# If no pagination requested, return all items
|
||||
if paginationParams is None:
|
||||
# Convert back to UserConnection objects (enum strings are already in dict)
|
||||
items = []
|
||||
for conn_dict in enhanced_connections_dict:
|
||||
conn_dict_copy = dict(conn_dict)
|
||||
if "authority" in conn_dict_copy and isinstance(conn_dict_copy["authority"], str):
|
||||
try:
|
||||
conn_dict_copy["authority"] = AuthAuthority(conn_dict_copy["authority"])
|
||||
except ValueError:
|
||||
pass
|
||||
if "status" in conn_dict_copy and isinstance(conn_dict_copy["status"], str):
|
||||
try:
|
||||
conn_dict_copy["status"] = ConnectionStatus(conn_dict_copy["status"])
|
||||
except ValueError:
|
||||
pass
|
||||
items.append(UserConnection(**conn_dict_copy))
|
||||
return PaginatedResponse(
|
||||
items=items,
|
||||
pagination=None
|
||||
)
|
||||
enhanced_connections.append(enhanced_connection)
|
||||
|
||||
return enhanced_connections
|
||||
# Apply filtering if provided
|
||||
if paginationParams.filters:
|
||||
component_interface = ComponentObjects()
|
||||
component_interface.setUserContext(currentUser)
|
||||
enhanced_connections_dict = component_interface._applyFilters(
|
||||
enhanced_connections_dict,
|
||||
paginationParams.filters
|
||||
)
|
||||
|
||||
# Apply sorting if provided
|
||||
if paginationParams.sort:
|
||||
component_interface = ComponentObjects()
|
||||
component_interface.setUserContext(currentUser)
|
||||
enhanced_connections_dict = component_interface._applySorting(
|
||||
enhanced_connections_dict,
|
||||
paginationParams.sort
|
||||
)
|
||||
|
||||
# Count total items after filters
|
||||
totalItems = len(enhanced_connections_dict)
|
||||
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||
|
||||
# Apply pagination (skip/limit)
|
||||
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
endIdx = startIdx + paginationParams.pageSize
|
||||
paged_connections = enhanced_connections_dict[startIdx:endIdx]
|
||||
|
||||
# Convert back to UserConnection objects (convert enum strings back to enums)
|
||||
items = []
|
||||
for conn_dict in paged_connections:
|
||||
# Convert enum strings back to enum objects
|
||||
conn_dict_copy = dict(conn_dict)
|
||||
if "authority" in conn_dict_copy and isinstance(conn_dict_copy["authority"], str):
|
||||
try:
|
||||
conn_dict_copy["authority"] = AuthAuthority(conn_dict_copy["authority"])
|
||||
except ValueError:
|
||||
pass # Keep as string if invalid
|
||||
if "status" in conn_dict_copy and isinstance(conn_dict_copy["status"], str):
|
||||
try:
|
||||
conn_dict_copy["status"] = ConnectionStatus(conn_dict_copy["status"])
|
||||
except ValueError:
|
||||
pass # Keep as string if invalid
|
||||
items.append(UserConnection(**conn_dict_copy))
|
||||
|
||||
return PaginatedResponse(
|
||||
items=items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting connections: {str(e)}")
|
||||
logger.error(f"Error getting connections: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get connections: {str(e)}"
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObj
|
|||
from modules.datamodels.datamodelFiles import FileItem, FilePreview
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -57,7 +57,10 @@ async def get_files(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
if paginationDict:
|
||||
# Normalize pagination dict (handles top-level "search" field)
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from modules.auth import limiter, getCurrentUser
|
|||
import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObjects
|
||||
from modules.datamodels.datamodelUtils import Prompt
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -46,7 +46,10 @@ async def get_prompts(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
if paginationDict:
|
||||
# Normalize pagination dict (handles top-level "search" field)
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from modules.datamodels.datamodelChat import (
|
|||
)
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
|
||||
|
||||
# Configure logger
|
||||
|
|
@ -69,7 +69,10 @@ async def get_workflows(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
if paginationDict:
|
||||
# Normalize pagination dict (handles top-level "search" field)
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
@ -262,7 +265,10 @@ async def get_workflow_logs(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
if paginationDict:
|
||||
# Normalize pagination dict (handles top-level "search" field)
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
@ -350,7 +356,10 @@ async def get_workflow_messages(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
if paginationDict:
|
||||
# Normalize pagination dict (handles top-level "search" field)
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
|
|||
|
|
@ -227,11 +227,28 @@ class WorkflowManager:
|
|||
workflow = self.services.workflow
|
||||
checkWorkflowStopped(self.services)
|
||||
|
||||
# Log fast path start
|
||||
self.services.chat.storeLog(workflow, {
|
||||
"message": "Fast path execution started",
|
||||
"type": "info",
|
||||
"status": "running",
|
||||
"progress": 0.1
|
||||
})
|
||||
|
||||
# Get user language if available
|
||||
userLanguage = getattr(self.services, 'currentUserLanguage', None)
|
||||
|
||||
# Execute fast path - use normalizedRequest if available, otherwise use raw prompt
|
||||
normalizedPrompt = getattr(self.services, 'currentUserPromptNormalized', None) or userInput.prompt
|
||||
|
||||
# Log fast path execution
|
||||
self.services.chat.storeLog(workflow, {
|
||||
"message": f"Processing request via fast path (language: {userLanguage or 'auto'})",
|
||||
"type": "info",
|
||||
"status": "running",
|
||||
"progress": 0.3
|
||||
})
|
||||
|
||||
result = await self.workflowProcessor.fastPathExecute(
|
||||
prompt=normalizedPrompt,
|
||||
documents=documents,
|
||||
|
|
@ -241,6 +258,12 @@ class WorkflowManager:
|
|||
if not result.success:
|
||||
# Fast path failed, fall back to full workflow
|
||||
logger.warning(f"Fast path failed: {result.error}, falling back to full workflow")
|
||||
self.services.chat.storeLog(workflow, {
|
||||
"message": f"Fast path failed: {result.error}. Falling back to full workflow.",
|
||||
"type": "warning",
|
||||
"status": "running",
|
||||
"progress": 0.5
|
||||
})
|
||||
taskPlan = await self._planTasks(userInput)
|
||||
await self._executeTasks(taskPlan)
|
||||
await self._processWorkflowResults()
|
||||
|
|
@ -288,7 +311,58 @@ class WorkflowManager:
|
|||
}
|
||||
chatDocuments.append(chatDoc)
|
||||
|
||||
# Mark workflow as completed BEFORE storing message (so UI polling stops)
|
||||
# Create initial user message first
|
||||
roundNum = workflow.currentRound
|
||||
contextLabel = f"round{roundNum}_usercontext"
|
||||
|
||||
userMessageData = {
|
||||
"workflowId": workflow.id,
|
||||
"role": "user",
|
||||
"message": userInput.prompt,
|
||||
"status": "first",
|
||||
"sequenceNr": 1,
|
||||
"publishedAt": self.services.utils.timestampGetUtc(),
|
||||
"documentsLabel": contextLabel,
|
||||
"documents": [],
|
||||
# Add workflow context fields
|
||||
"roundNumber": workflow.currentRound,
|
||||
"taskNumber": 0,
|
||||
"actionNumber": 0,
|
||||
# Add progress status
|
||||
"taskProgress": "pending",
|
||||
"actionProgress": "pending"
|
||||
}
|
||||
|
||||
# Store user message (with any uploaded documents)
|
||||
# Convert ChatDocument objects to dicts for storeMessageWithDocuments
|
||||
userDocuments = []
|
||||
for doc in documents:
|
||||
if isinstance(doc, dict):
|
||||
userDoc = dict(doc)
|
||||
else:
|
||||
# ChatDocument object - convert to dict
|
||||
userDoc = {
|
||||
"fileId": doc.fileId,
|
||||
"fileName": doc.fileName,
|
||||
"fileSize": doc.fileSize,
|
||||
"mimeType": doc.mimeType,
|
||||
"roundNumber": workflow.currentRound,
|
||||
"taskNumber": 0,
|
||||
"actionNumber": 0
|
||||
}
|
||||
userDocuments.append(userDoc)
|
||||
|
||||
self.services.chat.storeMessageWithDocuments(workflow, userMessageData, userDocuments)
|
||||
|
||||
# Log user message stored
|
||||
self.services.chat.storeLog(workflow, {
|
||||
"message": "User message stored",
|
||||
"type": "info",
|
||||
"status": "running",
|
||||
"progress": 0.6
|
||||
})
|
||||
|
||||
# Mark workflow as completed BEFORE storing response message (so UI polling stops)
|
||||
workflow.status = "completed"
|
||||
workflow.lastActivity = self.services.utils.timestampGetUtc()
|
||||
self.services.chat.updateWorkflow(workflow.id, {
|
||||
|
|
@ -296,6 +370,14 @@ class WorkflowManager:
|
|||
"lastActivity": workflow.lastActivity
|
||||
})
|
||||
|
||||
# Log response generation
|
||||
self.services.chat.storeLog(workflow, {
|
||||
"message": "Generating fast path response",
|
||||
"type": "info",
|
||||
"status": "running",
|
||||
"progress": 0.8
|
||||
})
|
||||
|
||||
# Create ChatMessage with fast path response (in user's language)
|
||||
messageData = {
|
||||
"workflowId": workflow.id,
|
||||
|
|
@ -318,10 +400,25 @@ class WorkflowManager:
|
|||
# Store message with documents
|
||||
self.services.chat.storeMessageWithDocuments(workflow, messageData, chatDocuments)
|
||||
|
||||
# Log fast path completion
|
||||
self.services.chat.storeLog(workflow, {
|
||||
"message": f"Fast path completed successfully (response length: {len(responseText)} chars)",
|
||||
"type": "info",
|
||||
"status": "completed",
|
||||
"progress": 1.0
|
||||
})
|
||||
|
||||
logger.info(f"Fast path completed successfully, response length: {len(responseText)} chars")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _executeFastPath: {str(e)}")
|
||||
# Log error
|
||||
self.services.chat.storeLog(workflow, {
|
||||
"message": f"Fast path error: {str(e)}. Falling back to full workflow.",
|
||||
"type": "error",
|
||||
"status": "running",
|
||||
"progress": 0.5
|
||||
})
|
||||
# Fall back to full workflow on error
|
||||
logger.info("Falling back to full workflow due to fast path error")
|
||||
taskPlan = await self._planTasks(userInput)
|
||||
|
|
|
|||
Loading…
Reference in a new issue