fixed stats and billing sync

This commit is contained in:
patrick-motsch 2026-02-08 01:44:43 +01:00
parent bbea0ff115
commit e4d41965f3
9 changed files with 151 additions and 47 deletions

View file

@@ -144,6 +144,9 @@ class AiCallOptions(BaseModel):
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)")
maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch")
# Provider filtering (from UI multiselect or automation config)
allowedProviders: Optional[List[str]] = Field(default=None, description="List of allowed AI providers to use (empty = all RBAC-permitted)")
class AiCallRequest(BaseModel):
"""Centralized AI call request payload for interface use."""

View file

@@ -403,8 +403,7 @@ class UserInputRequest(BaseModel):
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
userLanguage: str = Field(default="en", description="User's preferred language")
workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue")
preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders")
preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)")
allowedProviders: Optional[List[str]] = Field(None, description="List of allowed AI providers (multiselect)")
registerModelLabels(

View file

@@ -89,6 +89,17 @@ class AiObjects:
# Get failover models for this operation type
availableModels = modelRegistry.getAvailableModels()
# Filter by allowedProviders if specified (from workflow config)
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
if allowedProviders:
filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
if filteredModels:
logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
availableModels = filteredModels
else:
logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
if not failoverModelList:

View file

@@ -1532,8 +1532,19 @@ class ChatObjects:
return []
# Return all stats records sorted by creation time
stats.sort(key=lambda x: x.get("created_at", ""))
return [ChatStat(**stat) for stat in stats]
# DB uses _createdAt (camelCase system field)
stats.sort(key=lambda x: x.get("_createdAt", 0))
# Convert to ChatStat objects, preserving _createdAt via extra="allow"
result = []
for stat in stats:
chat_stat = ChatStat(**stat)
# Explicitly preserve _createdAt from raw DB record
if "_createdAt" in stat:
setattr(chat_stat, '_createdAt', stat["_createdAt"])
result.append(chat_stat)
return result
def createStat(self, statData: Dict[str, Any]) -> ChatStat:
@@ -1549,9 +1560,16 @@ class ChatObjects:
# Validate the stat data against ChatStat model
stat = ChatStat(**statData)
logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: "
f"process={statData.get('process')}, "
f"priceCHF={statData.get('priceCHF', 0):.4f}, "
f"processingTime={statData.get('processingTime', 0):.2f}s")
# Create the stat record in the database
created = self.db.recordCreate(ChatStat, stat)
logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}")
# Return the created ChatStat
return ChatStat(**created)
except Exception as e:
@@ -1643,10 +1661,14 @@ class ChatObjects:
if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
continue
# Convert to dict and include _createdAt for frontend
stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict()
stat_dict['_createdAt'] = stat_timestamp
items.append({
"type": "stat",
"createdAt": stat_timestamp,
"item": stat
"item": stat_dict
})
# Sort all items by createdAt timestamp for chronological order

View file

@@ -94,15 +94,30 @@ class AiService:
Includes billing checks:
- Balance check before AI call
- Provider permission check (via RBAC)
Also stores workflow stats after each successful AI call.
"""
# Billing check before AI call
# Billing check before AI call (validates RBAC permissions)
await self._checkBillingBeforeAiCall()
# Calculate effective allowedProviders: RBAC ∩ Workflow
# RBAC is master - only RBAC-permitted providers can ever be used
effectiveProviders = self._calculateEffectiveProviders()
if effectiveProviders and request.options:
request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}")
if hasattr(request, 'contentParts') and request.contentParts:
return await self.extractionService.processContentPartsWithAi(
response = await self.extractionService.processContentPartsWithAi(
request, self.aiObjects, progressCallback
)
return await self.aiObjects.callWithTextContext(request)
else:
response = await self.aiObjects.callWithTextContext(request)
# Store workflow stats after each AI call
self._storeAiCallStats(response, request)
return response
async def _checkBillingBeforeAiCall(self) -> None:
"""
@@ -206,6 +221,87 @@ class AiService:
# Log but don't block on billing check errors
logger.warning(f"Billing check failed with error (non-blocking): {e}")
def _calculateEffectiveProviders(self) -> Optional[List[str]]:
    """
    Calculate effective allowed providers: RBAC ∩ Workflow.

    RBAC is master - only RBAC-permitted providers can ever be used.
    If workflow specifies allowedProviders, intersect with RBAC.
    If no workflow providers, use all RBAC-permitted providers.

    Returns:
        List of effective allowed providers, or None if no filtering is
        needed (missing user/mandate context, no RBAC provider list, an
        empty intersection, or any error during calculation).
    """
    try:
        # Both user and mandate are required to resolve the billing
        # context; without them no provider filtering can be applied.
        user = getattr(self.services, 'user', None)
        mandateId = getattr(self.services, 'mandateId', None)
        if not user or not mandateId:
            return None
        # Get RBAC-permitted providers (master list)
        # Note: getBillingService is imported at module level from mainServiceBilling
        featureInstanceId = getattr(self.services, 'featureInstanceId', None)
        featureCode = getattr(self.services, 'featureCode', None)
        billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
        # NOTE(review): 'getallowedProviders' casing is inconsistent with the
        # camelCase used elsewhere (expected 'getAllowedProviders') - confirm
        # against the billing service API before renaming.
        rbacProviders = billingService.getallowedProviders()
        if not rbacProviders:
            logger.warning("No RBAC-permitted providers found")
            return None
        # Get workflow-specified providers (optional filter)
        workflowProviders = getattr(self.services, 'allowedProviders', None)
        if workflowProviders:
            # Intersect: only providers that are both RBAC-permitted AND workflow-allowed
            effectiveProviders = [p for p in workflowProviders if p in rbacProviders]
            logger.debug(f"Provider filter: RBAC={rbacProviders}, Workflow={workflowProviders}, Effective={effectiveProviders}")
        else:
            # No workflow filter - use all RBAC-permitted providers
            effectiveProviders = rbacProviders
            logger.debug(f"Provider filter: RBAC={rbacProviders} (no workflow filter)")
        # An empty intersection falls back to None ("no filtering") rather
        # than blocking every provider - callers treat None as unrestricted.
        return effectiveProviders if effectiveProviders else None
    except Exception as e:
        # Best-effort: provider filtering must never block the AI call itself.
        logger.warning(f"Error calculating effective providers: {e}")
        return None
def _storeAiCallStats(self, response, request: AiCallRequest) -> None:
    """Store workflow stats after an AI call.

    This method stores the AI call statistics (cost, processing time, bytes)
    to the workflow stats collection for tracking and billing purposes.

    Args:
        response: AiCallResponse with cost/timing data
        request: Original AiCallRequest for context
    """
    try:
        # Skip if no workflow context (some callers, e.g. chatbot-style
        # requests, have no workflow to attach stats to)
        workflow = getattr(self.services, 'workflow', None)
        if not workflow or not hasattr(workflow, 'id') or not workflow.id:
            logger.debug("No workflow context - skipping stats storage")
            return
        # Skip if response is an error - failed calls produce no usable
        # cost/timing data worth recording
        if not response or getattr(response, 'errorCount', 0) > 0:
            logger.debug("Error response - skipping stats storage")
            return
        # Determine process name from operation type
        # (getattr with default also covers options without operationType)
        opType = getattr(request.options, 'operationType', 'unknown') if request.options else 'unknown'
        process = f"ai.call.{opType}"
        # Store the stat
        self.services.chat.storeWorkflowStat(workflow, response, process)
        logger.debug(f"Stored AI call stat: {process}, cost={getattr(response, 'priceCHF', 0):.4f} CHF")
    except Exception as e:
        # Log but don't fail - stats storage is not critical
        logger.debug(f"Could not store AI call stat: {str(e)}")
async def ensureAiObjectsInitialized(self):
"""Ensure aiObjects is initialized and submodules are ready."""
if self.aiObjects is None:
@@ -428,7 +524,7 @@ Respond with ONLY a JSON object in this exact format:
# Debug: persist prompt/response for analysis with context-specific naming
debugPrefix = debugType if debugType else "plan"
self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt")
response = await self.aiObjects.callWithTextContext(request)
response = await self.callAi(request) # Use callAi to ensure stats are stored
result = response.content or ""
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
return result
@@ -485,16 +581,7 @@ Respond with ONLY a JSON object in this exact format:
operationType=opType.value
)
# Try to store workflow stats, but don't fail if workflow is None (e.g., in chatbot context)
try:
self.services.chat.storeWorkflowStat(
self.services.workflow,
response,
f"ai.{opType.name.lower()}"
)
except Exception as e:
# Log but don't fail - workflow might be None in some contexts (e.g., chatbot)
logger.debug(f"Could not store workflow stat (workflow may be None): {str(e)}")
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed")
self.services.chat.progressLogFinish(aiOperationId, True)

View file

@@ -269,17 +269,7 @@ class AiCallLooper:
# Document generation - save all iteration responses
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
# Emit stats for this iteration (only if workflow exists and has id)
if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id:
try:
self.services.chat.storeWorkflowStat(
self.services.workflow,
response,
f"ai.call.{debugPrefix}.iteration_{iteration}"
)
except Exception as statError:
# Don't break the main loop if stat storage fails
logger.warning(f"Failed to store workflow stat: {str(statError)}")
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
# Check for error response using generic error detection (errorCount > 0 or modelName == "error")
if hasattr(response, 'errorCount') and response.errorCount > 0:

View file

@@ -101,11 +101,7 @@ class ImageGenerationPath:
operationType=OperationTypeEnum.IMAGE_GENERATE.value
)
self.services.chat.storeWorkflowStat(
self.services.workflow,
response,
"ai.generate.image"
)
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated")
self.services.chat.progressLogFinish(imageOperationId, True)

View file

@@ -42,14 +42,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
try:
services = getServices(currentUser, mandateId=mandateId)
# Store preferred providers in services context for billing/model selection
# Support both preferredProviders (list) and legacy preferredProvider (string)
if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders:
services.preferredProviders = userInput.preferredProviders
logger.debug(f"Using preferred providers: {userInput.preferredProviders}")
elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider:
services.preferredProviders = [userInput.preferredProvider]
logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}")
# Store allowedProviders in services context for model selection
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
services.allowedProviders = userInput.allowedProviders
logger.info(f"AI provider filter active: {userInput.allowedProviders}")
# Store feature instance ID in services context
if featureInstanceId:

View file

@@ -14,7 +14,8 @@ from modules.workflows.processing.modes.modeDynamic import DynamicMode
from modules.workflows.processing.modes.modeAutomation import AutomationMode
from modules.workflows.processing.shared.stateTools import checkWorkflowStopped
from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson, parseJsonWithModel
from modules.datamodels.datamodelWorkflow import UnderstandingResult
if TYPE_CHECKING:
from modules.datamodels.datamodelWorkflow import TaskResult
@@ -477,8 +478,7 @@ class WorkflowProcessor:
maxProcessingTime=15 # Fast path should complete in 15s
)
# Call AI directly (no document generation - just plain text response)
# Use callWithTextContext() for text-only calls
# Call AI via callAi() to ensure stats are stored
aiRequest = AiCallRequest(
prompt=fastPathPrompt,
context="",
@@ -486,7 +486,7 @@ class WorkflowProcessor:
contentParts=None # Fast path doesn't process documents
)
aiCallResponse = await self.services.ai.aiObjects.callWithTextContext(aiRequest)
aiCallResponse = await self.services.ai.callAi(aiRequest)
# Extract response content (AiCallResponse.content is a string)
responseText = aiCallResponse.content if aiCallResponse.content else ""