fixed stats and billing sync
This commit is contained in:
parent
bbea0ff115
commit
e4d41965f3
9 changed files with 151 additions and 47 deletions
|
|
@ -144,6 +144,9 @@ class AiCallOptions(BaseModel):
|
|||
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)")
|
||||
maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch")
|
||||
|
||||
# Provider filtering (from UI multiselect or automation config)
|
||||
allowedProviders: Optional[List[str]] = Field(default=None, description="List of allowed AI providers to use (empty = all RBAC-permitted)")
|
||||
|
||||
|
||||
class AiCallRequest(BaseModel):
|
||||
"""Centralized AI call request payload for interface use."""
|
||||
|
|
|
|||
|
|
@ -403,8 +403,7 @@ class UserInputRequest(BaseModel):
|
|||
listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
|
||||
userLanguage: str = Field(default="en", description="User's preferred language")
|
||||
workflowId: Optional[str] = Field(None, description="Optional ID of the workflow to continue")
|
||||
preferredProvider: Optional[str] = Field(None, description="Preferred AI provider (e.g., 'anthropic', 'openai') - deprecated, use preferredProviders")
|
||||
preferredProviders: Optional[List[str]] = Field(None, description="List of preferred AI providers (multiselect)")
|
||||
allowedProviders: Optional[List[str]] = Field(None, description="List of allowed AI providers (multiselect)")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
|
|
|
|||
|
|
@ -89,6 +89,17 @@ class AiObjects:
|
|||
|
||||
# Get failover models for this operation type
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
|
||||
# Filter by allowedProviders if specified (from workflow config)
|
||||
allowedProviders = getattr(options, 'allowedProviders', None) if options else None
|
||||
if allowedProviders:
|
||||
filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
|
||||
if filteredModels:
|
||||
logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
|
||||
availableModels = filteredModels
|
||||
else:
|
||||
logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
|
||||
|
||||
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
|
||||
|
||||
if not failoverModelList:
|
||||
|
|
|
|||
|
|
@ -1532,8 +1532,19 @@ class ChatObjects:
|
|||
return []
|
||||
|
||||
# Return all stats records sorted by creation time
|
||||
stats.sort(key=lambda x: x.get("created_at", ""))
|
||||
return [ChatStat(**stat) for stat in stats]
|
||||
# DB uses _createdAt (camelCase system field)
|
||||
stats.sort(key=lambda x: x.get("_createdAt", 0))
|
||||
|
||||
# Convert to ChatStat objects, preserving _createdAt via extra="allow"
|
||||
result = []
|
||||
for stat in stats:
|
||||
chat_stat = ChatStat(**stat)
|
||||
# Explicitly preserve _createdAt from raw DB record
|
||||
if "_createdAt" in stat:
|
||||
setattr(chat_stat, '_createdAt', stat["_createdAt"])
|
||||
result.append(chat_stat)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def createStat(self, statData: Dict[str, Any]) -> ChatStat:
|
||||
|
|
@ -1549,9 +1560,16 @@ class ChatObjects:
|
|||
# Validate the stat data against ChatStat model
|
||||
stat = ChatStat(**statData)
|
||||
|
||||
logger.debug(f"Creating stat for workflow {statData.get('workflowId')}: "
|
||||
f"process={statData.get('process')}, "
|
||||
f"priceCHF={statData.get('priceCHF', 0):.4f}, "
|
||||
f"processingTime={statData.get('processingTime', 0):.2f}s")
|
||||
|
||||
# Create the stat record in the database
|
||||
created = self.db.recordCreate(ChatStat, stat)
|
||||
|
||||
logger.info(f"Created stat {created.get('id')} for workflow {statData.get('workflowId')}")
|
||||
|
||||
# Return the created ChatStat
|
||||
return ChatStat(**created)
|
||||
except Exception as e:
|
||||
|
|
@ -1643,10 +1661,14 @@ class ChatObjects:
|
|||
if afterTimestamp is not None and stat_timestamp <= afterTimestamp:
|
||||
continue
|
||||
|
||||
# Convert to dict and include _createdAt for frontend
|
||||
stat_dict = stat.model_dump() if hasattr(stat, 'model_dump') else stat.dict()
|
||||
stat_dict['_createdAt'] = stat_timestamp
|
||||
|
||||
items.append({
|
||||
"type": "stat",
|
||||
"createdAt": stat_timestamp,
|
||||
"item": stat
|
||||
"item": stat_dict
|
||||
})
|
||||
|
||||
# Sort all items by createdAt timestamp for chronological order
|
||||
|
|
|
|||
|
|
@ -94,15 +94,30 @@ class AiService:
|
|||
Includes billing checks:
|
||||
- Balance check before AI call
|
||||
- Provider permission check (via RBAC)
|
||||
|
||||
Also stores workflow stats after each successful AI call.
|
||||
"""
|
||||
# Billing check before AI call
|
||||
# Billing check before AI call (validates RBAC permissions)
|
||||
await self._checkBillingBeforeAiCall()
|
||||
|
||||
# Calculate effective allowedProviders: RBAC ∩ Workflow
|
||||
# RBAC is master - only RBAC-permitted providers can ever be used
|
||||
effectiveProviders = self._calculateEffectiveProviders()
|
||||
if effectiveProviders and request.options:
|
||||
request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
|
||||
logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}")
|
||||
|
||||
if hasattr(request, 'contentParts') and request.contentParts:
|
||||
return await self.extractionService.processContentPartsWithAi(
|
||||
response = await self.extractionService.processContentPartsWithAi(
|
||||
request, self.aiObjects, progressCallback
|
||||
)
|
||||
return await self.aiObjects.callWithTextContext(request)
|
||||
else:
|
||||
response = await self.aiObjects.callWithTextContext(request)
|
||||
|
||||
# Store workflow stats after each AI call
|
||||
self._storeAiCallStats(response, request)
|
||||
|
||||
return response
|
||||
|
||||
async def _checkBillingBeforeAiCall(self) -> None:
|
||||
"""
|
||||
|
|
@ -206,6 +221,87 @@ class AiService:
|
|||
# Log but don't block on billing check errors
|
||||
logger.warning(f"Billing check failed with error (non-blocking): {e}")
|
||||
|
||||
def _calculateEffectiveProviders(self) -> Optional[List[str]]:
|
||||
"""
|
||||
Calculate effective allowed providers: RBAC ∩ Workflow.
|
||||
|
||||
RBAC is master - only RBAC-permitted providers can ever be used.
|
||||
If workflow specifies allowedProviders, intersect with RBAC.
|
||||
If no workflow providers, use all RBAC-permitted providers.
|
||||
|
||||
Returns:
|
||||
List of effective allowed providers, or None if no filtering needed
|
||||
"""
|
||||
try:
|
||||
user = getattr(self.services, 'user', None)
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
|
||||
if not user or not mandateId:
|
||||
return None
|
||||
|
||||
# Get RBAC-permitted providers (master list)
|
||||
# Note: getBillingService is imported at module level from mainServiceBilling
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
rbacProviders = billingService.getallowedProviders()
|
||||
|
||||
if not rbacProviders:
|
||||
logger.warning("No RBAC-permitted providers found")
|
||||
return None
|
||||
|
||||
# Get workflow-specified providers (optional filter)
|
||||
workflowProviders = getattr(self.services, 'allowedProviders', None)
|
||||
|
||||
if workflowProviders:
|
||||
# Intersect: only providers that are both RBAC-permitted AND workflow-allowed
|
||||
effectiveProviders = [p for p in workflowProviders if p in rbacProviders]
|
||||
logger.debug(f"Provider filter: RBAC={rbacProviders}, Workflow={workflowProviders}, Effective={effectiveProviders}")
|
||||
else:
|
||||
# No workflow filter - use all RBAC-permitted providers
|
||||
effectiveProviders = rbacProviders
|
||||
logger.debug(f"Provider filter: RBAC={rbacProviders} (no workflow filter)")
|
||||
|
||||
return effectiveProviders if effectiveProviders else None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating effective providers: {e}")
|
||||
return None
|
||||
|
||||
def _storeAiCallStats(self, response, request: AiCallRequest) -> None:
|
||||
"""Store workflow stats after an AI call.
|
||||
|
||||
This method stores the AI call statistics (cost, processing time, bytes)
|
||||
to the workflow stats collection for tracking and billing purposes.
|
||||
|
||||
Args:
|
||||
response: AiCallResponse with cost/timing data
|
||||
request: Original AiCallRequest for context
|
||||
"""
|
||||
try:
|
||||
# Skip if no workflow context
|
||||
workflow = getattr(self.services, 'workflow', None)
|
||||
if not workflow or not hasattr(workflow, 'id') or not workflow.id:
|
||||
logger.debug("No workflow context - skipping stats storage")
|
||||
return
|
||||
|
||||
# Skip if response is an error
|
||||
if not response or getattr(response, 'errorCount', 0) > 0:
|
||||
logger.debug("Error response - skipping stats storage")
|
||||
return
|
||||
|
||||
# Determine process name from operation type
|
||||
opType = getattr(request.options, 'operationType', 'unknown') if request.options else 'unknown'
|
||||
process = f"ai.call.{opType}"
|
||||
|
||||
# Store the stat
|
||||
self.services.chat.storeWorkflowStat(workflow, response, process)
|
||||
logger.debug(f"Stored AI call stat: {process}, cost={getattr(response, 'priceCHF', 0):.4f} CHF")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - stats storage is not critical
|
||||
logger.debug(f"Could not store AI call stat: {str(e)}")
|
||||
|
||||
async def ensureAiObjectsInitialized(self):
|
||||
"""Ensure aiObjects is initialized and submodules are ready."""
|
||||
if self.aiObjects is None:
|
||||
|
|
@ -428,7 +524,7 @@ Respond with ONLY a JSON object in this exact format:
|
|||
# Debug: persist prompt/response for analysis with context-specific naming
|
||||
debugPrefix = debugType if debugType else "plan"
|
||||
self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt")
|
||||
response = await self.aiObjects.callWithTextContext(request)
|
||||
response = await self.callAi(request) # Use callAi to ensure stats are stored
|
||||
result = response.content or ""
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
|
||||
return result
|
||||
|
|
@ -485,16 +581,7 @@ Respond with ONLY a JSON object in this exact format:
|
|||
operationType=opType.value
|
||||
)
|
||||
|
||||
# Try to store workflow stats, but don't fail if workflow is None (e.g., in chatbot context)
|
||||
try:
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
f"ai.{opType.name.lower()}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail - workflow might be None in some contexts (e.g., chatbot)
|
||||
logger.debug(f"Could not store workflow stat (workflow may be None): {str(e)}")
|
||||
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed")
|
||||
self.services.chat.progressLogFinish(aiOperationId, True)
|
||||
|
|
|
|||
|
|
@ -269,17 +269,7 @@ class AiCallLooper:
|
|||
# Document generation - save all iteration responses
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
|
||||
|
||||
# Emit stats for this iteration (only if workflow exists and has id)
|
||||
if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id:
|
||||
try:
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
f"ai.call.{debugPrefix}.iteration_{iteration}"
|
||||
)
|
||||
except Exception as statError:
|
||||
# Don't break the main loop if stat storage fails
|
||||
logger.warning(f"Failed to store workflow stat: {str(statError)}")
|
||||
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||
|
||||
# Check for error response using generic error detection (errorCount > 0 or modelName == "error")
|
||||
if hasattr(response, 'errorCount') and response.errorCount > 0:
|
||||
|
|
|
|||
|
|
@ -101,11 +101,7 @@ class ImageGenerationPath:
|
|||
operationType=OperationTypeEnum.IMAGE_GENERATE.value
|
||||
)
|
||||
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
"ai.generate.image"
|
||||
)
|
||||
# Note: Stats are now stored centrally in callAi() - no need to duplicate here
|
||||
|
||||
self.services.chat.progressLogUpdate(imageOperationId, 0.9, "Image generated")
|
||||
self.services.chat.progressLogFinish(imageOperationId, True)
|
||||
|
|
|
|||
|
|
@ -42,14 +42,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
|||
try:
|
||||
services = getServices(currentUser, mandateId=mandateId)
|
||||
|
||||
# Store preferred providers in services context for billing/model selection
|
||||
# Support both preferredProviders (list) and legacy preferredProvider (string)
|
||||
if hasattr(userInput, 'preferredProviders') and userInput.preferredProviders:
|
||||
services.preferredProviders = userInput.preferredProviders
|
||||
logger.debug(f"Using preferred providers: {userInput.preferredProviders}")
|
||||
elif hasattr(userInput, 'preferredProvider') and userInput.preferredProvider:
|
||||
services.preferredProviders = [userInput.preferredProvider]
|
||||
logger.debug(f"Using preferred provider (legacy): {userInput.preferredProvider}")
|
||||
# Store allowedProviders in services context for model selection
|
||||
if hasattr(userInput, 'allowedProviders') and userInput.allowedProviders:
|
||||
services.allowedProviders = userInput.allowedProviders
|
||||
logger.info(f"AI provider filter active: {userInput.allowedProviders}")
|
||||
|
||||
# Store feature instance ID in services context
|
||||
if featureInstanceId:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from modules.workflows.processing.modes.modeDynamic import DynamicMode
|
|||
from modules.workflows.processing.modes.modeAutomation import AutomationMode
|
||||
from modules.workflows.processing.shared.stateTools import checkWorkflowStopped
|
||||
from modules.datamodels.datamodelAi import OperationTypeEnum, PriorityEnum, ProcessingModeEnum, AiCallOptions, AiCallRequest
|
||||
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson
|
||||
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson, parseJsonWithModel
|
||||
from modules.datamodels.datamodelWorkflow import UnderstandingResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from modules.datamodels.datamodelWorkflow import TaskResult
|
||||
|
|
@ -477,8 +478,7 @@ class WorkflowProcessor:
|
|||
maxProcessingTime=15 # Fast path should complete in 15s
|
||||
)
|
||||
|
||||
# Call AI directly (no document generation - just plain text response)
|
||||
# Use callWithTextContext() for text-only calls
|
||||
# Call AI via callAi() to ensure stats are stored
|
||||
aiRequest = AiCallRequest(
|
||||
prompt=fastPathPrompt,
|
||||
context="",
|
||||
|
|
@ -486,7 +486,7 @@ class WorkflowProcessor:
|
|||
contentParts=None # Fast path doesn't process documents
|
||||
)
|
||||
|
||||
aiCallResponse = await self.services.ai.aiObjects.callWithTextContext(aiRequest)
|
||||
aiCallResponse = await self.services.ai.callAi(aiRequest)
|
||||
|
||||
# Extract response content (AiCallResponse.content is a string)
|
||||
responseText = aiCallResponse.content if aiCallResponse.content else ""
|
||||
|
|
|
|||
Loading…
Reference in a new issue