# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

import logging
import asyncio
import uuid
import base64
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
from dataclasses import dataclass, field
import time

logger = logging.getLogger(__name__)

from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.datamodels.datamodelAi import (
    AiModel,
    AiCallOptions,
    AiCallRequest,
    AiCallResponse,
    OperationTypeEnum,
    AiModelCall,
    AiModelResponse,
)
from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy

# Dynamic model registry - models are now loaded from connectors via aicore system

# Placeholder token that prompts may embed; it is substituted with the selected
# model's maxTokens value before the call is made (see _callWithModel).
# NOTE(review): the original literal was lost in the source (the code read
# `if "" in prompt`, which matches EVERY string and would have interleaved the
# token limit between every character via str.replace). Confirm this constant
# against the prompt templates that actually use the placeholder.
MAX_TOKENS_PLACEHOLDER = "{{maxTokens}}"


@dataclass(slots=True)
class AiObjects:
    """Centralized AI interface: dynamically discovers and uses AI models.

    billingCallback: Set by serviceAi before AI calls. Called after EVERY
    individual model call with the AiCallResponse. This ensures
    per-model-call billing with exact provider + model name. The callback
    handles billing recording.
    """

    # Billing hook injected by serviceAi; excluded from repr to keep logs clean.
    billingCallback: Optional[Callable] = field(default=None, repr=False)

    def __post_init__(self) -> None:
        # Auto-discover and register all available connectors
        self._discoverAndRegisterConnectors()

    def _discoverAndRegisterConnectors(self) -> None:
        """Auto-discover and register all available AI connectors."""
        logger.info("Auto-discovering AI connectors...")

        # Use the model registry's built-in discovery mechanism
        discoveredConnectors = modelRegistry.discoverConnectors()

        # Register each discovered connector
        for connector in discoveredConnectors:
            modelRegistry.registerConnector(connector)
            logger.info(f"Registered connector: {connector.getConnectorType()}")

        logger.info(f"Total connectors registered: {len(discoveredConnectors)}")
        logger.info("All AI connectors registered with dynamic model registry")

    @classmethod
    async def create(cls) -> "AiObjects":
        """Create AiObjects instance with auto-discovered connectors."""
        # No need to manually create connectors - they're auto-discovered
        return cls()

    def _selectModel(self, prompt: str, context: str, options: AiCallOptions) -> str:
        """Select the best model using dynamic model selection system.

        Returns:
            displayName (unique identifier) of the selected model.

        Raises:
            ValueError: if no models are registered or none match the criteria.
        """
        # Get available models from the dynamic registry
        availableModels = modelRegistry.getAvailableModels()
        if not availableModels:
            logger.error("No models available in the registry")
            raise ValueError("No AI models available")

        # Use the dynamic model selector
        selectedModel = modelSelector.selectModel(prompt, context, options, availableModels)
        if not selectedModel:
            logger.error("No suitable model found for the given criteria")
            raise ValueError("No suitable AI model found")

        logger.info(f"Selected model: {selectedModel.name} ({selectedModel.displayName})")
        return selectedModel.displayName

    # AI for Extraction, Processing, Generation
    async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
        """Call AI model for traditional text/context calls with fallback mechanism.

        Supports two modes:
        - Legacy: prompt + context -> constructs messages internally
        - Agent: request.messages provided -> passes through directly

        Never raises on model failure: after all failover attempts an error
        AiCallResponse (modelName="error", errorCount=1) is returned instead.
        """
        prompt = request.prompt
        context = request.context or ""
        options = request.options

        # Get failover models for this operation type
        availableModels = modelRegistry.getAvailableModels()

        # Filter by allowedProviders if specified (from workflow config)
        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
            if filteredModels:
                logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
                availableModels = filteredModels
            else:
                # Fail open: an over-restrictive config should not block all calls
                logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")

        failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
        if not failoverModelList:
            errorMsg = f"No suitable models found for operation {options.operationType}"
            logger.error(errorMsg)
            return self._createErrorResponse(errorMsg, 0, 0)

        # Try each model in failover sequence
        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                if request.messages:
                    # Agent mode: pre-built conversation, optionally with tools
                    response = await self._callWithMessages(model, request.messages, options, request.tools)
                else:
                    # Legacy mode: build messages from prompt + context
                    response = await self._callWithModel(model, prompt, context, options)
                logger.info(f"AI call successful with model: {model.name}")
                return response
            except Exception as e:
                lastError = e
                logger.warning(f"AI call failed with model {model.name}: {str(e)}")
                # Let the selector deprioritize the failing model in the future
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    logger.info("Trying next failover model...")
                    continue
                else:
                    logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                    break

        # All failover attempts failed - return error response
        errorMsg = f"All AI models failed for operation {options.operationType}. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return self._createErrorResponse(errorMsg, 0, 0)

    def _createErrorResponse(self, errorMsg: str, inputBytes: int, outputBytes: int) -> AiCallResponse:
        """Create an error response (modelName='error', errorCount=1, zero price)."""
        return AiCallResponse(
            content=errorMsg,
            modelName="error",
            priceCHF=0.0,
            processingTime=0.0,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=1
        )

    async def _callWithModel(self, model: AiModel, prompt: str, context: str,
                             options: Optional[AiCallOptions] = None) -> AiCallResponse:
        """Call a specific model with prompt/context and return the response.

        Builds a messages array (system context + user prompt), measures
        timing and byte counts, computes the price via the model's own
        calculatepriceCHF, and records billing through billingCallback.

        Raises:
            ValueError: if the model has no functionCall, the call reports
                failure, or maxTokens is invalid while the prompt contains
                the max-tokens placeholder.
        """
        # Calculate input bytes from prompt and context
        inputBytes = len((prompt + context).encode('utf-8'))

        # Replace placeholder with model's maxTokens value
        if MAX_TOKENS_PLACEHOLDER in prompt:
            if model.maxTokens > 0:
                tokenLimit = str(model.maxTokens)
                modelPrompt = prompt.replace(MAX_TOKENS_PLACEHOLDER, tokenLimit)
                logger.debug(f"Replaced {MAX_TOKENS_PLACEHOLDER} with {tokenLimit} for model {model.name}")
            else:
                raise ValueError(f"Model {model.name} has invalid maxTokens ({model.maxTokens}). Cannot set token limit.")
        else:
            modelPrompt = prompt

        # Update messages array with replaced content
        messages = []
        if context:
            messages.append({"role": "system", "content": f"Context from documents:\n{context}"})
        messages.append({"role": "user", "content": modelPrompt})

        # Start timing
        startTime = time.time()

        # Call the model's function directly - completely generic
        if model.functionCall:
            # Create standardized call object
            modelCall = AiModelCall(
                messages=messages,
                model=model,
                options=options or {}
            )

            # Log before calling model
            contextSize = len(context.encode('utf-8')) if context else 0
            promptSize = len(modelPrompt.encode('utf-8')) if modelPrompt else 0
            totalInputSize = contextSize + promptSize
            logger.debug(f"Calling model {model.name} with {len(messages)} messages, context size: {contextSize} bytes, prompt size: {promptSize} bytes, total input: {totalInputSize} bytes")

            # Call the model with standardized interface
            modelResponse = await model.functionCall(modelCall)

            # Log after successful call
            logger.debug(f"Model {model.name} returned successfully")

            # Extract content from standardized response
            if not modelResponse.success:
                raise ValueError(f"Model call failed: {modelResponse.error}")
            content = modelResponse.content
        else:
            raise ValueError(f"Model {model.name} has no function call defined")

        # Calculate timing and output bytes
        endTime = time.time()
        processingTime = endTime - startTime
        outputBytes = len(content.encode("utf-8"))

        # Calculate price using model's own price calculation method
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0
        )

        # BILLING: Record billing for THIS specific model call
        # billingCallback is set by serviceAi and records one billing transaction
        # per model call with exact provider + model name
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                # Billing must never break the AI call itself
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")

        return response

    async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]],
                                options: Optional[AiCallOptions] = None,
                                tools: Optional[List[Dict[str, Any]]] = None) -> AiCallResponse:
        """Call a model with pre-built messages (agent mode).

        Supports tools for native function calling; tool calls returned by the
        model are surfaced via AiCallResponse.toolCalls (read from
        modelResponse.metadata["toolCalls"]).

        Raises:
            ValueError: if the model has no functionCall or the call fails.
        """
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCall:
            raise ValueError(f"Model {model.name} has no function call defined")

        modelCall = AiModelCall(
            messages=messages,
            model=model,
            options=options or {},
            tools=tools
        )
        modelResponse = await model.functionCall(modelCall)

        if not modelResponse.success:
            raise ValueError(f"Model call failed: {modelResponse.error}")

        endTime = time.time()
        processingTime = endTime - startTime
        content = modelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        # Extract tool calls from metadata if present (native function calling)
        responseToolCalls = None
        if modelResponse.metadata:
            responseToolCalls = modelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls
        )

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                # Billing must never break the AI call itself
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")

        return response

    async def callWithTextContextStream(
        self, request: AiCallRequest
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Streaming variant of callWithTextContext.

        Yields str deltas, then a final AiCallResponse. Like the non-stream
        variant, model failures trigger failover and ultimately yield an
        error AiCallResponse rather than raising.
        """
        options = request.options
        availableModels = modelRegistry.getAvailableModels()

        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered

        failoverModelList = modelSelector.getFailoverModelList(
            request.prompt, request.context or "", options, availableModels
        )
        if not failoverModelList:
            yield AiCallResponse(
                content=f"No suitable models found for operation {options.operationType}",
                modelName="error",
                priceCHF=0.0,
                processingTime=0.0,
                bytesSent=0,
                bytesReceived=0,
                errorCount=1,
            )
            return

        # Legacy mode support: without pre-built messages, construct them from
        # prompt + context exactly like _callWithModel does (previously this
        # path crashed on messages=None).
        messages = request.messages
        if not messages:
            messages = []
            if request.context:
                messages.append({"role": "system", "content": f"Context from documents:\n{request.context}"})
            messages.append({"role": "user", "content": request.prompt})

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})")
                async for chunk in self._callWithMessagesStream(model, messages, options, request.tools):
                    yield chunk
                return
            except Exception as e:
                lastError = e
                logger.warning(f"Streaming AI call failed with {model.name}: {e}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        yield AiCallResponse(
            content=f"All models failed (stream). Last error: {lastError}",
            modelName="error",
            priceCHF=0.0,
            processingTime=0.0,
            bytesSent=0,
            bytesReceived=0,
            errorCount=1,
        )

    async def _callWithMessagesStream(
        self,
        model: AiModel,
        messages: List[Dict[str, Any]],
        options: Optional[AiCallOptions] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Stream a model call.

        Yields str deltas, then the final AiCallResponse with billing. Models
        without a functionCallStream fall back to the non-streaming call and
        yield its full content as a single delta.

        Raises:
            ValueError: if the stream never produces a final AiModelResponse.
        """
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCallStream:
            # Fallback: non-streaming call, emitted as one delta + final response
            response = await self._callWithMessages(model, messages, options, tools)
            if response.content:
                yield response.content
            yield response
            return

        modelCall = AiModelCall(
            messages=messages,
            model=model,
            options=options or {},
            tools=tools,
        )

        finalModelResponse = None
        async for item in model.functionCallStream(modelCall):
            if isinstance(item, AiModelResponse):
                # Terminal marker carrying the aggregated result
                finalModelResponse = item
            else:
                yield item

        if not finalModelResponse:
            raise ValueError(f"Stream from {model.name} produced no final AiModelResponse")

        endTime = time.time()
        processingTime = endTime - startTime
        content = finalModelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        responseToolCalls = None
        if finalModelResponse.metadata:
            responseToolCalls = finalModelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                # Billing must never break the AI call itself
                logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}")

        yield response

    async def callEmbedding(self, texts: List[str],
                            options: Optional[AiCallOptions] = None) -> AiCallResponse:
        """Generate embeddings for a list of texts using the best available embedding model.

        Uses the standard model selector with OperationTypeEnum.EMBEDDING to pick
        the model. Failover across providers works identically to chat models.
        Never raises on model failure: an error AiCallResponse is returned
        after all failover attempts.

        Returns:
            AiCallResponse with metadata["embeddings"] containing the vectors.
        """
        if options is None:
            options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING)
        else:
            # Force the embedding operation type regardless of caller settings
            options.operationType = OperationTypeEnum.EMBEDDING

        # A short sample of the input is enough for model selection heuristics
        combinedText = " ".join(texts[:3])[:500]
        availableModels = modelRegistry.getAvailableModels()
        failoverModelList = modelSelector.getFailoverModelList(
            combinedText, "", options, availableModels
        )
        if not failoverModelList:
            return AiCallResponse(
                content="",
                modelName="error",
                priceCHF=0.0,
                processingTime=0.0,
                bytesSent=0,
                bytesReceived=0,
                errorCount=1
            )

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                inputBytes = sum(len(t.encode("utf-8")) for t in texts)
                startTime = time.time()

                modelCall = AiModelCall(
                    model=model,
                    options=options,
                    embeddingInput=texts
                )
                modelResponse = await model.functionCall(modelCall)
                if not modelResponse.success:
                    raise ValueError(f"Embedding call failed: {modelResponse.error}")

                processingTime = time.time() - startTime
                # Embeddings produce no text output; price on input bytes only
                priceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
                embeddings = (modelResponse.metadata or {}).get("embeddings", [])

                response = AiCallResponse(
                    content="",
                    modelName=model.name,
                    provider=model.connectorType,
                    priceCHF=priceCHF,
                    processingTime=processingTime,
                    bytesSent=inputBytes,
                    bytesReceived=0,
                    errorCount=0,
                    metadata={"embeddings": embeddings}
                )

                if self.billingCallback:
                    try:
                        self.billingCallback(response)
                    except Exception as e:
                        # Billing must never break the AI call itself
                        logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}")

                return response
            except Exception as e:
                lastError = e
                logger.warning(f"Embedding call failed with {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        errorMsg = f"All embedding models failed. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return self._createErrorResponse(errorMsg, 0, 0)

    # Utility methods
    async def listAvailableModels(self, connectorType: Optional[str] = None) -> List[Dict[str, Any]]:
        """List available models, optionally filtered by connector type."""
        models = modelRegistry.getAvailableModels()
        if connectorType:
            return [model.model_dump() for model in models if model.connectorType == connectorType]
        return [model.model_dump() for model in models]

    async def getModelInfo(self, displayName: str) -> Dict[str, Any]:
        """Get information about a specific model by displayName.

        Raises:
            ValueError: if no model with that displayName is registered.
        """
        model = modelRegistry.getModel(displayName)
        if not model:
            raise ValueError(f"Model with displayName '{displayName}' not found")
        return model.model_dump()

    async def getModelsByTag(self, tag: str) -> List[str]:
        """Get model displayNames that have a specific tag.

        Returns displayNames (unique identifiers).
        """
        models = modelRegistry.getModelsByTag(tag)
        return [model.displayName for model in models]