fixed stats and billing sync
This commit is contained in:
parent
bbea0ff115
commit
e4d41965f3
9 changed files with 151 additions and 47 deletions
|
|
@ -144,6 +144,9 @@ class AiCallOptions(BaseModel):
|
||||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)")
|
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)")
|
||||||
maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch")
|
maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch")
|
||||||
|
|
||||||
|
# Provider filtering (from UI multiselect or automation config)
|
||||||
|
allowedProviders: Optional[List[str]] = Field(default=None, description="List of allowed AI providers to use (empty = all RBAC-permitted)")
|
||||||
|
|
||||||
|
|
||||||
class AiCallRequest(BaseModel):
|
class AiCallRequest(BaseModel):
|
||||||
"""Centralized AI call request payload for interface use."""
|
"""Centralized AI call request payload for interface use."""
|
||||||
|
|
|
||||||
|
|
@ -403,8 +403,7 @@ class UserInputRequest(BaseModel):
|
||||||
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
|
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
|
||||||
userLanguage: str = Field(default="en", description="User's preferred language")
|
userLanguage: str = Field(default="en", description="User's preferred language")
|
||||||
workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue")
|
workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue")
|
||||||
preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders")
|
allowedProviders: Optional[List[str]] = Field(None, description="List of allowed AI providers (multiselect)")
|
||||||
preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)")
|
|
||||||
|
|
||||||
|
|
||||||
registerModelLabels(
|
registerModelLabels(
|
||||||
|
|
|
||||||
|
|
@ -89,6 +89,17 @@ class AiObjects:
|
||||||
|
|
||||||
# Get failover models for this operation type
|
# Get failover models for this operation type
|
||||||
availableModels = modelRegistry.getAvailableModels()
|
availableModels = modelRegistry.getAvailableModels()
|
||||||
|
|
||||||
|
# Filter by allowedProviders if specified (from workflow config)
|
||||||
|
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
|
||||||
|
if allowedProviders:
|
||||||
|
filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
|
||||||
|
if filteredModels:
|
||||||
|
logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
|
||||||
|
availableModels = filteredModels
|
||||||
|
else:
|
||||||
|
logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
|
||||||
|
|
||||||
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
|
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
|
||||||
|
|
||||||
if not failoverModelList:
|
if not failoverModelList:
|
||||||
|
|
|
||||||
|
|
@ -1532,8 +1532,19 @@ class ChatObjects:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Return all stats records sorted by creation time
|
# Return all stats records sorted by creation time
|
||||||
stats.sort(key=lambda x: x.get("created_at", ""))
|
# DB uses _createdAt (camelCase system field)
|
||||||
return [ChatStat(**stat) for stat in stats]
|
stats.sort(key=lambda x: x.get("_createdAt", 0))
|
||||||
|
|
||||||
|
# Convert to ChatStat objects, preserving _createdAt via extra="allow"
|
||||||
|
result = []
|
||||||
|
for stat in stats:
|
||||||
|
chat_stat = ChatStat(**stat)
|
||||||
|
# Explicitly preserve _createdAt from raw DB record
|
||||||
|
if "_createdAt" in stat:
|
||||||
|
setattr(chat_stat, '_createdAt', stat["_createdAt"])
|
||||||
|
result.append(chat_stat)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def createStat(self, statData: Dict[str, Any]) -> ChatStat:
|
def createStat(self, statData: Dict[str, Any]) -> ChatStat:
|
||||||
|
|
@ -1549,9 +1560,16 @@ class ChatObjects:
|
||||||
# Validate the stat data against ChatStat model
|
# Validate the stat data against ChatStat model
|
||||||
stat = ChatStat(**statData)
|
stat = ChatStat(**statData)
|
||||||
|
|
||||||
|
logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: "
|
||||||
|
f"process={statData.get('process')}, "
|
||||||
|
f"priceCHF={statData.get('priceCHF', 0):.4f}, "
|
||||||
|
f"processingTime={statData.get('processingTime', 0):.2f}s")
|
||||||
|
|
||||||
# Create the stat record in the database
|
# Create the stat record in the database
|
||||||
created = self.db.recordCreate(ChatStat, stat)
|
created = self.db.recordCreate(ChatStat, stat)
|
||||||
|
|
||||||
|
logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}")
|
||||||
|
|
||||||
# Return the created ChatStat
|
# Return the created ChatStat
|
||||||
return ChatStat(**created)
|
return ChatStat(**created)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1643,10 +1661,14 @@ class ChatObjects:
|
||||||
if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
|
if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Convert to dict and include _createdAt for frontend
|
||||||
|
stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict()
|
||||||
|
stat_dict['_createdAt'] = stat_timestamp
|
||||||
|
|
||||||
items.append({
|
items.append({
|
||||||
"type": "stat",
|
"type": "stat",
|
||||||
"createdAt": stat_timestamp,
|
"createdAt": stat_timestamp,
|
||||||
"item": stat
|
"item": stat_dict
|
||||||
})
|
})
|
||||||
|
|
||||||
# Sort all items by createdAt timestamp for chronological order
|
# Sort all items by createdAt timestamp for chronological order
|
||||||
|
|
|
||||||
|
|
@ -94,15 +94,30 @@ class AiService:
|
||||||
Includes billing checks:
|
Includes billing checks:
|
||||||
- Balance check before AI call
|
- Balance check before AI call
|
||||||
- Provider permission check (via RBAC)
|
- Provider permission check (via RBAC)
|
||||||
|
|
||||||
|
Also stores workflow stats after each successful AI call.
|
||||||
"""
|
"""
|
||||||
# Billing check before AI call
|
# Billing check before AI call (validates RBAC permissions)
|
||||||
await self._checkBillingBeforeAiCall()
|
await self._checkBillingBeforeAiCall()
|
||||||
|
|
||||||
|
# Calculate effective allowedProviders: RBAC ∩ Workflow
|
||||||
|
# RBAC is master - only RBAC-permitted providers can ever be used
|
||||||
|
effectiveProviders = self._calculateEffectiveProviders()
|
||||||
|
if effectiveProviders and request.options:
|
||||||
|
request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
|
||||||
|
logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}")
|
||||||
|
|
||||||
if hasattr(request, 'contentParts') and request.contentParts:
|
if hasattr(request, 'contentParts') and request.contentParts:
|
||||||
return await self.extractionService.processContentPartsWithAi(
|
response = await self.extractionService.processContentPartsWithAi(
|
||||||
request, self.aiObjects, progressCallback
|
request, self.aiObjects, progressCallback
|
||||||
)
|
)
|
||||||
return await self.aiObjects.callWithTextContext(request)
|
else:
|
||||||
|
response = await self.aiObjects.callWithTextContext(request)
|
||||||
|
|
||||||
|
# Store workflow stats after each AI call
|
||||||
|
self._storeAiCallStats(response, request)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
async def _checkBillingBeforeAiCall(self) -> None:
|
async def _checkBillingBeforeAiCall(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -206,6 +221,87 @@ class AiService:
|
||||||
# Log but don't block on billing check errors
|
# Log but don't block on billing check errors
|
||||||
logger.warning(f"Billing check failed with error (non-blocking): {e}")
|
logger.warning(f"Billing check failed with error (non-blocking): {e}")
|
||||||
|
|
||||||
|
def _calculateEffectiveProviders(self) -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
Calculate effective allowed providers: RBAC ∩ Workflow.
|
||||||
|
|
||||||
|
RBAC is master - only RBAC-permitted providers can ever be used.
|
||||||
|
If workflow specifies allowedProviders, intersect with RBAC.
|
||||||
|
If no workflow providers, use all RBAC-permitted providers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of effective allowed providers, or None if no filtering needed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = getattr(self.services, 'user', None)
|
||||||
|
mandateId = getattr(self.services, 'mandateId', None)
|
||||||
|
|
||||||
|
if not user or not mandateId:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get RBAC-permitted providers (master list)
|
||||||
|
# Note: getBillingService is imported at module level from mainServiceBilling
|
||||||
|
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||||
|
featureCode = getattr(self.services, 'featureCode', None)
|
||||||
|
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||||
|
rbacProviders = billingService.getallowedProviders()
|
||||||
|
|
||||||
|
if not rbacProviders:
|
||||||
|
logger.warning("No RBAC-permitted providers found")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get workflow-specified providers (optional filter)
|
||||||
|
workflowProviders = getattr(self.services, 'allowedProviders', None)
|
||||||
|
|
||||||
|
if workflowProviders:
|
||||||
|
# Intersect: only providers that are both RBAC-permitted AND workflow-allowed
|
||||||
|
effectiveProviders = [p for p in workflowProviders if p in rbacProviders]
|
||||||
|
logger.debug(f"Provider filter: RBAC={rbacProviders}, Workflow={workflowProviders}, Effective={effectiveProviders}")
|
||||||
|
else:
|
||||||
|
# No workflow filter - use all RBAC-permitted providers
|
||||||
|
effectiveProviders = rbacProviders
|
||||||
|
logger.debug(f"Provider filter: RBAC={rbacProviders} (no workflow filter)")
|
||||||
|
|
||||||
|
return effectiveProviders if effectiveProviders else None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error calculating effective providers: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _storeAiCallStats(self, response, request: AiCallRequest) -> None:
|
||||||
|
"""Store workflow stats after an AI call.
|
||||||
|
|
||||||
|
This method stores the AI call statistics (cost, processing time, bytes)
|
||||||
|
to the workflow stats collection for tracking and billing purposes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: AiCallResponse with cost/timing data
|
||||||
|
request: Original AiCallRequest for context
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Skip if no workflow context
|
||||||
|
workflow = getattr(self.services, 'workflow', None)
|
||||||
|
if not workflow or not hasattr(workflow, 'id') or not workflow.id:
|
||||||
|
logger.debug("No workflow context - skipping stats storage")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip if response is an error
|
||||||
|
if not response or getattr(response, 'errorCount', 0) > 0:
|
||||||
|
logger.debug("Error response - skipping stats storage")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine process name from operation type
|
||||||
|
opType = getattr(request.options, 'operationType', 'unknown') if request.options else 'unknown'
|
||||||
|
process = f"ai.call.{opType}"
|
||||||
|
|
||||||
|
# Store the stat
|
||||||
|
self.services.chat.storeWorkflowStat(workflow, response, process)
|
||||||
|
logger.debug(f"Stored AI call stat: {process}, cost={getattr(response, 'priceCHF', 0):.4f} CHF")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log but don't fail - stats storage is not critical
|
||||||
|
logger.debug(f"Could not store AI call stat: {str(e)}")
|
||||||
|
|
||||||
async def ensureAiObjectsInitialized(self):
|
async def ensureAiObjectsInitialized(self):
|
||||||
"""Ensure aiObjects is initialized and submodules are ready."""
|
"""Ensure aiObjects is initialized and submodules are ready."""
|
||||||
if self.aiObjects is None:
|
if self.aiObjects is None:
|
||||||
|
|
@ -428,7 +524,7 @@ Respond with ONLY a JSON object in this exact format:
|
||||||
# Debug: persist prompt/response for analysis with context-specific naming
|
# Debug: persist prompt/response for analysis with context-specific naming
|
||||||
debugPrefix = debugType if debugType else "plan"
|
debugPrefix = debugType if debugType else "plan"
|
||||||
self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt")
|
self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt")
|
||||||
response = await self.aiObjects.callWithTextContext(request)
|
response = await self.callAi(request) # Use callAi to ensure stats are stored
|
||||||
result = response.content or ""
|
result = response.content or ""
|
||||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
|
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
|
||||||
return result
|
return result
|
||||||
|
|
@ -485,16 +581,7 @@ Respond with ONLY a JSON object in this exact format:
|
||||||
operationType=opType.value
|
operationType=opType.value
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to store workflow stats, but don't fail if workflow is None (e.g., in chatbot context)
|
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||||
try:
|
|
||||||
self.services.chat.storeWorkflowStat(
|
|
||||||
self.services.workflow,
|
|
||||||
response,
|
|
||||||
f"ai.{opType.name.lower()}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# Log but don't fail - workflow might be None in some contexts (e.g., chatbot)
|
|
||||||
logger.debug(f"Could not store workflow stat (workflow may be None): {str(e)}")
|
|
||||||
|
|
||||||
self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed")
|
self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed")
|
||||||
self.services.chat.progressLogFinish(aiOperationId, True)
|
self.services.chat.progressLogFinish(aiOperationId, True)
|
||||||
|
|
|
||||||
|
|
@ -269,17 +269,7 @@ class AiCallLooper:
|
||||||
# Document generation - save all iteration responses
|
# Document generation - save all iteration responses
|
||||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
|
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
|
||||||
|
|
||||||
# Emit stats for this iteration (only if workflow exists and has id)
|
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||||
if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id:
|
|
||||||
try:
|
|
||||||
self.services.chat.storeWorkflowStat(
|
|
||||||
self.services.workflow,
|
|
||||||
response,
|
|
||||||
f"ai.call.{debugPrefix}.iteration_{iteration}"
|
|
||||||
)
|
|
||||||
except Exception as statError:
|
|
||||||
# Don't break the main loop if stat storage fails
|
|
||||||
logger.warning(f"Failed to store workflow stat: {str(statError)}")
|
|
||||||
|
|
||||||
# Check for error response using generic error detection (errorCount > 0 or modelName == "error")
|
# Check for error response using generic error detection (errorCount > 0 or modelName == "error")
|
||||||
if hasattr(response, 'errorCount') and response.errorCount > 0:
|
if hasattr(response, 'errorCount') and response.errorCount > 0:
|
||||||
|
|
|
||||||
|
|
@ -101,11 +101,7 @@ class ImageGenerationPath:
|
||||||
operationType=OperationTypeEnum.IMAGE_GENERATE.value
|
operationType=OperationTypeEnum.IMAGE_GENERATE.value
|
||||||
)
|
)
|
||||||
|
|
||||||
self.services.chat.storeWorkflowStat(
|
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||||
self.services.workflow,
|
|
||||||
response,
|
|
||||||
"ai.generate.image"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated")
|
self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated")
|
||||||
self.services.chat.progressLogFinish(imageOperationId, True)
|
self.services.chat.progressLogFinish(imageOperationId, True)
|
||||||
|
|
|
||||||
|
|
@ -42,14 +42,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
||||||
try:
|
try:
|
||||||
services = getServices(currentUser, mandateId=mandateId)
|
services = getServices(currentUser, mandateId=mandateId)
|
||||||
|
|
||||||
# Store preferred providers in services context for billing/model selection
|
# Store allowedProviders in services context for model selection
|
||||||
# Support both preferredProviders (list) and legacy preferredProvider (string)
|
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
|
||||||
if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders:
|
services.allowedProviders = userInput.allowedProviders
|
||||||
services.preferredProviders = userInput.preferredProviders
|
logger.info(f"AI provider filter active: {userInput.allowedProviders}")
|
||||||
logger.debug(f"Using preferred providers: {userInput.preferredProviders}")
|
|
||||||
elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider:
|
|
||||||
services.preferredProviders = [userInput.preferredProvider]
|
|
||||||
logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}")
|
|
||||||
|
|
||||||
# Store feature instance ID in services context
|
# Store feature instance ID in services context
|
||||||
if featureInstanceId:
|
if featureInstanceId:
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,8 @@ from modules.workflows.processing.modes.modeDynamic import DynamicMode
|
||||||
from modules.workflows.processing.modes.modeAutomation import AutomationMode
|
from modules.workflows.processing.modes.modeAutomation import AutomationMode
|
||||||
from modules.workflows.processing.shared.stateTools import checkWorkflowStopped
|
from modules.workflows.processing.shared.stateTools import checkWorkflowStopped
|
||||||
from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest
|
from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest
|
||||||
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson
|
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson, parseJsonWithModel
|
||||||
|
from modules.datamodels.datamodelWorkflow import UnderstandingResult
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from modules.datamodels.datamodelWorkflow import TaskResult
|
from modules.datamodels.datamodelWorkflow import TaskResult
|
||||||
|
|
@ -477,8 +478,7 @@ class WorkflowProcessor:
|
||||||
maxProcessingTime=15 # Fast path should complete in 15s
|
maxProcessingTime=15 # Fast path should complete in 15s
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call AI directly (no document generation - just plain text response)
|
# Call AI via callAi() to ensure stats are stored
|
||||||
# Use callWithTextContext() for text-only calls
|
|
||||||
aiRequest = AiCallRequest(
|
aiRequest = AiCallRequest(
|
||||||
prompt=fastPathPrompt,
|
prompt=fastPathPrompt,
|
||||||
context="",
|
context="",
|
||||||
|
|
@ -486,7 +486,7 @@ class WorkflowProcessor:
|
||||||
contentParts=None # Fast path doesn't process documents
|
contentParts=None # Fast path doesn't process documents
|
||||||
)
|
)
|
||||||
|
|
||||||
aiCallResponse = await self.services.ai.aiObjects.callWithTextContext(aiRequest)
|
aiCallResponse = await self.services.ai.callAi(aiRequest)
|
||||||
|
|
||||||
# Extract response content (AiCallResponse.content is a string)
|
# Extract response content (AiCallResponse.content is a string)
|
||||||
responseText = aiCallResponse.content if aiCallResponse.content else ""
|
responseText = aiCallResponse.content if aiCallResponse.content else ""
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue