fixed stats and billing sync

This commit is contained in:
patrick-motsch 2026-02-08 01:44:43 +01:00
parent bbea0ff115
commit e4d41965f3
9 changed files with 151 additions and 47 deletions

View file

@ -144,6 +144,9 @@ class AiCallOptions(BaseModel):
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)") temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)")
maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch") maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch")
# Provider filtering (from UI multiselect or automation config)
allowedProviders: Optional[List[str]] = Field(default=None, description="List of allowed AI providers to use (empty = all RBAC-permitted)")
class AiCallRequest(BaseModel): class AiCallRequest(BaseModel):
"""Centralized AI call request payload for interface use.""" """Centralized AI call request payload for interface use."""

View file

@ -403,8 +403,7 @@ class UserInputRequest(BaseModel):
listFileId: List[str] = Field(default_factory=list, description="List of file IDs") listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
userLanguage: str = Field(default="en", description="User's preferred language") userLanguage: str = Field(default="en", description="User's preferred language")
workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue") workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue")
preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders") allowedProviders: Optional[List[str]] = Field(None, description="List of allowed AI providers (multiselect)")
preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)")
registerModelLabels( registerModelLabels(

View file

@ -89,6 +89,17 @@ class AiObjects:
# Get failover models for this operation type # Get failover models for this operation type
availableModels = modelRegistry.getAvailableModels() availableModels = modelRegistry.getAvailableModels()
# Filter by allowedProviders if specified (from workflow config)
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
if allowedProviders:
filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
if filteredModels:
logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
availableModels = filteredModels
else:
logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels) failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
if not failoverModelList: if not failoverModelList:

View file

@ -1532,8 +1532,19 @@ class ChatObjects:
return [] return []
# Return all stats records sorted by creation time # Return all stats records sorted by creation time
stats.sort(key=lambda x: x.get("created_at", "")) # DB uses _createdAt (camelCase system field)
return [ChatStat(**stat) for stat in stats] stats.sort(key=lambda x: x.get("_createdAt", 0))
# Convert to ChatStat objects, preserving _createdAt via extra="allow"
result = []
for stat in stats:
chat_stat = ChatStat(**stat)
# Explicitly preserve _createdAt from raw DB record
if "_createdAt" in stat:
setattr(chat_stat, '_createdAt', stat["_createdAt"])
result.append(chat_stat)
return result
def createStat(self, statData: Dict[str, Any]) -> ChatStat: def createStat(self, statData: Dict[str, Any]) -> ChatStat:
@ -1549,9 +1560,16 @@ class ChatObjects:
# Validate the stat data against ChatStat model # Validate the stat data against ChatStat model
stat = ChatStat(**statData) stat = ChatStat(**statData)
logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: "
f"process={statData.get('process')}, "
f"priceCHF={statData.get('priceCHF', 0):.4f}, "
f"processingTime={statData.get('processingTime', 0):.2f}s")
# Create the stat record in the database # Create the stat record in the database
created = self.db.recordCreate(ChatStat, stat) created = self.db.recordCreate(ChatStat, stat)
logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}")
# Return the created ChatStat # Return the created ChatStat
return ChatStat(**created) return ChatStat(**created)
except Exception as e: except Exception as e:
@ -1643,10 +1661,14 @@ class ChatObjects:
if afterTimestamp is not None and stat_timestamp <= afterTimestamp: if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
continue continue
# Convert to dict and include _createdAt for frontend
stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict()
stat_dict['_createdAt'] = stat_timestamp
items.append({ items.append({
"type": "stat", "type": "stat",
"createdAt": stat_timestamp, "createdAt": stat_timestamp,
"item": stat "item": stat_dict
}) })
# Sort all items by createdAt timestamp for chronological order # Sort all items by createdAt timestamp for chronological order

View file

@ -94,15 +94,30 @@ class AiService:
Includes billing checks: Includes billing checks:
- Balance check before AI call - Balance check before AI call
- Provider permission check (via RBAC) - Provider permission check (via RBAC)
Also stores workflow stats after each successful AI call.
""" """
# Billing check before AI call # Billing check before AI call (validates RBAC permissions)
await self._checkBillingBeforeAiCall() await self._checkBillingBeforeAiCall()
# Calculate effective allowedProviders: RBAC ∩ Workflow
# RBAC is master - only RBAC-permitted providers can ever be used
effectiveProviders = self._calculateEffectiveProviders()
if effectiveProviders and request.options:
request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}")
if hasattr(request, 'contentParts') and request.contentParts: if hasattr(request, 'contentParts') and request.contentParts:
return await self.extractionService.processContentPartsWithAi( response = await self.extractionService.processContentPartsWithAi(
request, self.aiObjects, progressCallback request, self.aiObjects, progressCallback
) )
return await self.aiObjects.callWithTextContext(request) else:
response = await self.aiObjects.callWithTextContext(request)
# Store workflow stats after each AI call
self._storeAiCallStats(response, request)
return response
async def _checkBillingBeforeAiCall(self) -> None: async def _checkBillingBeforeAiCall(self) -> None:
""" """
@ -206,6 +221,87 @@ class AiService:
# Log but don't block on billing check errors # Log but don't block on billing check errors
logger.warning(f"Billing check failed with error (non-blocking): {e}") logger.warning(f"Billing check failed with error (non-blocking): {e}")
def _calculateEffectiveProviders(self) -> Optional[List[str]]:
"""
Calculate effective allowed providers: RBAC Workflow.
RBAC is master - only RBAC-permitted providers can ever be used.
If workflow specifies allowedProviders, intersect with RBAC.
If no workflow providers, use all RBAC-permitted providers.
Returns:
List of effective allowed providers, or None if no filtering needed
"""
try:
user = getattr(self.services, 'user', None)
mandateId = getattr(self.services, 'mandateId', None)
if not user or not mandateId:
return None
# Get RBAC-permitted providers (master list)
# Note: getBillingService is imported at module level from mainServiceBilling
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
featureCode = getattr(self.services, 'featureCode', None)
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
rbacProviders = billingService.getallowedProviders()
if not rbacProviders:
logger.warning("No RBAC-permitted providers found")
return None
# Get workflow-specified providers (optional filter)
workflowProviders = getattr(self.services, 'allowedProviders', None)
if workflowProviders:
# Intersect: only providers that are both RBAC-permitted AND workflow-allowed
effectiveProviders = [p for p in workflowProviders if p in rbacProviders]
logger.debug(f"Provider filter: RBAC={rbacProviders}, Workflow={workflowProviders}, Effective={effectiveProviders}")
else:
# No workflow filter - use all RBAC-permitted providers
effectiveProviders = rbacProviders
logger.debug(f"Provider filter: RBAC={rbacProviders} (no workflow filter)")
return effectiveProviders if effectiveProviders else None
except Exception as e:
logger.warning(f"Error calculating effective providers: {e}")
return None
def _storeAiCallStats(self, response, request: AiCallRequest) -> None:
    """Store workflow stats after an AI call.

    Persists the call's statistics (cost, processing time, bytes) to the
    workflow stats collection for tracking and billing purposes.

    Args:
        response: AiCallResponse with cost/timing data
        request: Original AiCallRequest for context
    """
    try:
        # Guard: stats are only meaningful inside a workflow context.
        wf = getattr(self.services, 'workflow', None)
        if not (wf and getattr(wf, 'id', None)):
            logger.debug("No workflow context - skipping stats storage")
            return

        # Guard: never record stats for failed calls.
        if not response or getattr(response, 'errorCount', 0) > 0:
            logger.debug("Error response - skipping stats storage")
            return

        # Derive the process label from the request's operation type.
        operation = 'unknown'
        if request.options:
            operation = getattr(request.options, 'operationType', 'unknown')
        label = f"ai.call.{operation}"

        # Persist the stat record.
        self.services.chat.storeWorkflowStat(wf, response, label)
        logger.debug(f"Stored AI call stat: {label}, cost={getattr(response, 'priceCHF', 0):.4f} CHF")
    except Exception as e:
        # Stats storage is best-effort; never fail the AI call over it.
        logger.debug(f"Could not store AI call stat: {str(e)}")
async def ensureAiObjectsInitialized(self): async def ensureAiObjectsInitialized(self):
"""Ensure aiObjects is initialized and submodules are ready.""" """Ensure aiObjects is initialized and submodules are ready."""
if self.aiObjects is None: if self.aiObjects is None:
@ -428,7 +524,7 @@ Respond with ONLY a JSON object in this exact format:
# Debug: persist prompt/response for analysis with context-specific naming # Debug: persist prompt/response for analysis with context-specific naming
debugPrefix = debugType if debugType else "plan" debugPrefix = debugType if debugType else "plan"
self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt") self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt")
response = await self.aiObjects.callWithTextContext(request) response = await self.callAi(request) # Use callAi to ensure stats are stored
result = response.content or "" result = response.content or ""
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response") self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
return result return result
@ -485,16 +581,7 @@ Respond with ONLY a JSON object in this exact format:
operationType=opType.value operationType=opType.value
) )
# Try to store workflow stats, but don't fail if workflow is None (e.g., in chatbot context) # Note: Stats are now stored centrally in callAi() - no need to duplicate here
try:
self.services.chat.storeWorkflowStat(
self.services.workflow,
response,
f"ai.{opType.name.lower()}"
)
except Exception as e:
# Log but don't fail - workflow might be None in some contexts (e.g., chatbot)
logger.debug(f"Could not store workflow stat (workflow may be None): {str(e)}")
self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed") self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed")
self.services.chat.progressLogFinish(aiOperationId, True) self.services.chat.progressLogFinish(aiOperationId, True)

View file

@ -269,17 +269,7 @@ class AiCallLooper:
# Document generation - save all iteration responses # Document generation - save all iteration responses
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}") self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
# Emit stats for this iteration (only if workflow exists and has id) # Note: Stats are now stored centrally in callAi() - no need to duplicate here
if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id:
try:
self.services.chat.storeWorkflowStat(
self.services.workflow,
response,
f"ai.call.{debugPrefix}.iteration_{iteration}"
)
except Exception as statError:
# Don't break the main loop if stat storage fails
logger.warning(f"Failed to store workflow stat: {str(statError)}")
# Check for error response using generic error detection (errorCount > 0 or modelName == "error") # Check for error response using generic error detection (errorCount > 0 or modelName == "error")
if hasattr(response, 'errorCount') and response.errorCount > 0: if hasattr(response, 'errorCount') and response.errorCount > 0:

View file

@ -101,11 +101,7 @@ class ImageGenerationPath:
operationType=OperationTypeEnum.IMAGE_GENERATE.value operationType=OperationTypeEnum.IMAGE_GENERATE.value
) )
self.services.chat.storeWorkflowStat( # Note: Stats are now stored centrally in callAi() - no need to duplicate here
self.services.workflow,
response,
"ai.generate.image"
)
self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated") self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated")
self.services.chat.progressLogFinish(imageOperationId, True) self.services.chat.progressLogFinish(imageOperationId, True)

View file

@ -42,14 +42,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
try: try:
services = getServices(currentUser, mandateId=mandateId) services = getServices(currentUser, mandateId=mandateId)
# Store preferred providers in services context for billing/model selection # Store allowedProviders in services context for model selection
# Support both preferredProviders (list) and legacy preferredProvider (string) if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders: services.allowedProviders = userInput.allowedProviders
services.preferredProviders = userInput.preferredProviders logger.info(f"AI provider filter active: {userInput.allowedProviders}")
logger.debug(f"Using preferred providers: {userInput.preferredProviders}")
elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider:
services.preferredProviders = [userInput.preferredProvider]
logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}")
# Store feature instance ID in services context # Store feature instance ID in services context
if featureInstanceId: if featureInstanceId:

View file

@ -14,7 +14,8 @@ from modules.workflows.processing.modes.modeDynamic import DynamicMode
from modules.workflows.processing.modes.modeAutomation import AutomationMode from modules.workflows.processing.modes.modeAutomation import AutomationMode
from modules.workflows.processing.shared.stateTools import checkWorkflowStopped from modules.workflows.processing.shared.stateTools import checkWorkflowStopped
from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson from modules.shared.jsonUtils import extractJsonString, repairBrokenJson, parseJsonWithModel
from modules.datamodels.datamodelWorkflow import UnderstandingResult
if TYPE_CHECKING: if TYPE_CHECKING:
from modules.datamodels.datamodelWorkflow import TaskResult from modules.datamodels.datamodelWorkflow import TaskResult
@ -477,8 +478,7 @@ class WorkflowProcessor:
maxProcessingTime=15 # Fast path should complete in 15s maxProcessingTime=15 # Fast path should complete in 15s
) )
# Call AI directly (no document generation - just plain text response) # Call AI via callAi() to ensure stats are stored
# Use callWithTextContext() for text-only calls
aiRequest = AiCallRequest( aiRequest = AiCallRequest(
prompt=fastPathPrompt, prompt=fastPathPrompt,
context="", context="",
@ -486,7 +486,7 @@ class WorkflowProcessor:
contentParts=None # Fast path doesn't process documents contentParts=None # Fast path doesn't process documents
) )
aiCallResponse = await self.services.ai.aiObjects.callWithTextContext(aiRequest) aiCallResponse = await self.services.ai.callAi(aiRequest)
# Extract response content (AiCallResponse.content is a string) # Extract response content (AiCallResponse.content is a string)
responseText = aiCallResponse.content if aiCallResponse.content else "" responseText = aiCallResponse.content if aiCallResponse.content else ""