# Copyright (c) 2025 Patrick Motsch # All rights reserved. import logging import asyncio import uuid import base64 from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator from dataclasses import dataclass, field import time logger = logging.getLogger(__name__) from modules.aicore.aicoreModelRegistry import modelRegistry from modules.aicore.aicoreModelSelector import modelSelector from modules.aicore.aicoreBase import RateLimitExceededException from modules.datamodels.datamodelAi import ( AiModel, AiCallOptions, AiCallRequest, AiCallResponse, OperationTypeEnum, AiModelCall, AiModelResponse, ) from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy # Dynamic model registry - models are now loaded from connectors via aicore system @dataclass(slots=True) class AiObjects: """Centralized AI interface: dynamically discovers and uses AI models. billingCallback: Set by serviceAi before AI calls. Called after EVERY individual model call with the AiCallResponse. This ensures per-model-call billing with exact provider + model name. The callback handles billing recording. """ billingCallback: Optional[Callable] = field(default=None, repr=False) def __post_init__(self) -> None: # Auto-discover and register all available connectors self._discoverAndRegisterConnectors() def _discoverAndRegisterConnectors(self): """Auto-discover and register all available AI connectors.""" logger.info("Auto-discovering AI connectors...") # Use the model registry's built-in discovery mechanism discoveredConnectors = modelRegistry.discoverConnectors() # Register each discovered connector for connector in discoveredConnectors: modelRegistry.registerConnector(connector) logger.info(f"Registered connector: {connector.getConnectorType()}") logger.info(f"Total connectors registered: {len(discoveredConnectors)}") logger.info("All AI connectors registered with dynamic model registry") @classmethod async def create(cls) -> "AiObjects": """Create AiObjects instance with auto-discovered connectors.""" # No need to manually create connectors - they're auto-discovered return cls() def _selectModel(self, prompt: str, context: str, options: AiCallOptions) -> str: """Select the best model using dynamic model selection system. Returns displayName (unique identifier).""" # Get available models from the dynamic registry availableModels = modelRegistry.getAvailableModels() if not availableModels: logger.error("No models available in the registry") raise ValueError("No AI models available") # Use the dynamic model selector selectedModel = modelSelector.selectModel(prompt, context, options, availableModels) if not selectedModel: logger.error("No suitable model found for the given criteria") raise ValueError("No suitable AI model found") logger.info(f"Selected model: {selectedModel.name} ({selectedModel.displayName})") return selectedModel.displayName # AI for Extraction, Processing, Generation async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse: """Call AI model for traditional text/context calls with fallback mechanism. Supports two modes: - Legacy: prompt + context → constructs messages internally - Agent: request.messages provided → passes through directly """ prompt = request.prompt context = request.context or "" options = request.options # Get failover models for this operation type availableModels = modelRegistry.getAvailableModels() allowedProviders = getattr(options, 'allowedProviders', None) if options else None if allowedProviders: filteredModels = [m for m in availableModels if m.connectorType in allowedProviders] if filteredModels: availableModels = filteredModels else: errorMsg = f"No models match allowedProviders {allowedProviders} for operation {options.operationType}" logger.error(errorMsg) return AiCallResponse( content=errorMsg, modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1, ) failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels) if not failoverModelList: errorMsg = f"No suitable models found for operation {options.operationType}" logger.error(errorMsg) return AiCallResponse( content=errorMsg, modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1 ) _MAX_SHORT_RETRY = 15.0 lastError = None for attempt, model in enumerate(failoverModelList): try: logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})") if request.messages: response = await self._callWithMessages(model, request.messages, options, request.tools, toolChoice=request.toolChoice) else: response = await self._callWithModel(model, prompt, context, options) logger.info(f"AI call successful with model: {model.name}") return response except RateLimitExceededException as rle: retryAfter = rle.retryAfterSeconds lastError = rle if 0 < retryAfter <= _MAX_SHORT_RETRY: logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry") await asyncio.sleep(retryAfter + 0.5) try: if request.messages: response = await self._callWithMessages(model, request.messages, options, request.tools, toolChoice=request.toolChoice) else: response = await self._callWithModel(model, prompt, context, options) logger.info(f"AI call successful with {model.name} after rate-limit retry") return response except Exception as retryErr: lastError = retryErr logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}") else: logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over") cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0 modelSelector.reportFailure(model.name, cooldownSeconds=cooldown) if attempt < len(failoverModelList) - 1: continue logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}") break except Exception as e: lastError = e logger.warning(f"AI call failed with model {model.name}: {str(e)}") modelSelector.reportFailure(model.name) if attempt < len(failoverModelList) - 1: logger.info(f"Trying next failover model...") continue else: logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}") break # All failover attempts failed - return error response errorMsg = f"All AI models failed for operation {options.operationType}. Last error: {str(lastError)}" logger.error(errorMsg) return AiCallResponse( content=errorMsg, modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1 ) def _createErrorResponse(self, errorMsg: str, inputBytes: int, outputBytes: int) -> AiCallResponse: """Create an error response.""" return AiCallResponse( content=errorMsg, modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=inputBytes, bytesReceived=outputBytes, errorCount=1 ) async def _callWithModel(self, model: AiModel, prompt: str, context: str, options: AiCallOptions = None) -> AiCallResponse: """Call a specific model and return the response.""" # Calculate input bytes from prompt and context inputBytes = len((prompt + context).encode('utf-8')) # Replace placeholder with model's maxTokens value if "" in prompt: if model.maxTokens > 0: tokenLimit = str(model.maxTokens) modelPrompt = prompt.replace("", tokenLimit) logger.debug(f"Replaced with {tokenLimit} for model {model.name}") else: raise ValueError(f"Model {model.name} has invalid maxTokens ({model.maxTokens}). Cannot set token limit.") else: modelPrompt = prompt # Update messages array with replaced content messages = [] if context: messages.append({"role": "system", "content": f"Context from documents:\n{context}"}) messages.append({"role": "user", "content": modelPrompt}) # Start timing startTime = time.time() # Call the model's function directly - completely generic if model.functionCall: # Create standardized call object modelCall = AiModelCall( messages=messages, model=model, options=options or {} ) # Log before calling model contextSize = len(context.encode('utf-8')) if context else 0 promptSize = len(modelPrompt.encode('utf-8')) if modelPrompt else 0 totalInputSize = contextSize + promptSize logger.debug(f"Calling model {model.name} with {len(messages)} messages, context size: {contextSize} bytes, prompt size: {promptSize} bytes, total input: {totalInputSize} bytes") # Call the model with standardized interface modelResponse = await model.functionCall(modelCall) # Log after successful call logger.debug(f"Model {model.name} returned successfully") # Extract content from standardized response if not modelResponse.success: raise ValueError(f"Model call failed: {modelResponse.error}") content = modelResponse.content else: raise ValueError(f"Model {model.name} has no function call defined") # Calculate timing and output bytes endTime = time.time() processingTime = endTime - startTime outputBytes = len(content.encode("utf-8")) # Calculate price using model's own price calculation method priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes) response = AiCallResponse( content=content, modelName=model.name, provider=model.connectorType, priceCHF=priceCHF, processingTime=processingTime, bytesSent=inputBytes, bytesReceived=outputBytes, errorCount=0 ) # BILLING: Record billing for THIS specific model call # billingCallback is set by serviceAi and records one billing transaction # per model call with exact provider + model name if self.billingCallback: try: self.billingCallback(response) except Exception as e: logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}") return response async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]], options: AiCallOptions = None, tools: List[Dict[str, Any]] = None, toolChoice: Any = None) -> AiCallResponse: """Call a model with pre-built messages (agent mode). Supports tools for native function calling.""" import json as _json inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages) startTime = time.time() if not model.functionCall: raise ValueError(f"Model {model.name} has no function call defined") modelCall = AiModelCall( messages=messages, model=model, options=options or {}, tools=tools, toolChoice=toolChoice, ) modelResponse = await model.functionCall(modelCall) if not modelResponse.success: raise ValueError(f"Model call failed: {modelResponse.error}") endTime = time.time() processingTime = endTime - startTime content = modelResponse.content outputBytes = len(content.encode("utf-8")) priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes) # Extract tool calls from metadata if present (native function calling) responseToolCalls = None if modelResponse.metadata: responseToolCalls = modelResponse.metadata.get("toolCalls") response = AiCallResponse( content=content, modelName=model.name, provider=model.connectorType, priceCHF=priceCHF, processingTime=processingTime, bytesSent=inputBytes, bytesReceived=outputBytes, errorCount=0, toolCalls=responseToolCalls ) response._modelMaxTokens = model.maxTokens if self.billingCallback: try: self.billingCallback(response) except Exception as e: logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}") return response async def callWithTextContextStream( self, request: AiCallRequest ) -> AsyncGenerator[Union[str, AiCallResponse], None]: """Streaming variant of callWithTextContext. Yields str deltas, then final AiCallResponse.""" options = request.options availableModels = modelRegistry.getAvailableModels() allowedProviders = getattr(options, 'allowedProviders', None) if options else None if allowedProviders: filtered = [m for m in availableModels if m.connectorType in allowedProviders] if filtered: availableModels = filtered else: yield AiCallResponse( content=f"No models match allowedProviders {allowedProviders} for operation {options.operationType}", modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1, ) return failoverModelList = modelSelector.getFailoverModelList( request.prompt, request.context or "", options, availableModels ) if not failoverModelList: yield AiCallResponse( content=f"No suitable models found for operation {options.operationType}", modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1, ) return _MAX_SHORT_RETRY = 15.0 lastError = None for attempt, model in enumerate(failoverModelList): try: logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})") async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools, toolChoice=request.toolChoice): yield chunk return except RateLimitExceededException as rle: retryAfter = rle.retryAfterSeconds lastError = rle if 0 < retryAfter <= _MAX_SHORT_RETRY: logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry") await asyncio.sleep(retryAfter + 0.5) try: async for chunk in self._callWithMessagesStream(model, request.messages, options, request.tools, toolChoice=request.toolChoice): yield chunk return except Exception as retryErr: lastError = retryErr logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}") else: logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over") cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0 modelSelector.reportFailure(model.name, cooldownSeconds=cooldown) if attempt < len(failoverModelList) - 1: continue break except Exception as e: lastError = e logger.warning(f"Streaming AI call failed with {model.name}: {e}") modelSelector.reportFailure(model.name) if attempt < len(failoverModelList) - 1: continue break yield AiCallResponse( content=f"All models failed (stream). Last error: {lastError}", modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1, ) async def _callWithMessagesStream( self, model: AiModel, messages: List[Dict[str, Any]], options: AiCallOptions = None, tools: List[Dict[str, Any]] = None, toolChoice: Any = None, ) -> AsyncGenerator[Union[str, AiCallResponse], None]: """Stream a model call. Yields str deltas, then final AiCallResponse with billing.""" from modules.datamodels.datamodelAi import AiModelCall, AiModelResponse inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages) startTime = time.time() if not model.functionCallStream: response = await self._callWithMessages(model, messages, options, tools, toolChoice=toolChoice) if response.content: yield response.content yield response return modelCall = AiModelCall( messages=messages, model=model, options=options or {}, tools=tools, toolChoice=toolChoice, ) finalModelResponse = None async for item in model.functionCallStream(modelCall): if isinstance(item, AiModelResponse): finalModelResponse = item else: yield item if not finalModelResponse: raise ValueError(f"Stream from {model.name} produced no final AiModelResponse") endTime = time.time() processingTime = endTime - startTime content = finalModelResponse.content outputBytes = len(content.encode("utf-8")) priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes) responseToolCalls = None if finalModelResponse.metadata: responseToolCalls = finalModelResponse.metadata.get("toolCalls") response = AiCallResponse( content=content, modelName=model.name, provider=model.connectorType, priceCHF=priceCHF, processingTime=processingTime, bytesSent=inputBytes, bytesReceived=outputBytes, errorCount=0, toolCalls=responseToolCalls, ) response._modelMaxTokens = model.maxTokens if self.billingCallback: try: self.billingCallback(response) except Exception as e: logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}") yield response async def callEmbedding(self, texts: List[str], options: AiCallOptions = None) -> AiCallResponse: """Generate embeddings for a list of texts using the best available embedding model. Token-aware batching: splits the texts list into batches that respect the model's contextLength (with 10% safety margin). Each batch is sent as a separate API call; the resulting embeddings are merged in order. Failover across providers (OpenAI -> Mistral) works identically to chat models, but ContextLengthExceededException is NOT retried via failover (same limits). Returns: AiCallResponse with metadata["embeddings"] containing the vectors. """ from modules.aicore.aicoreBase import ContextLengthExceededException as _CtxExc if options is None: options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING) else: options.operationType = OperationTypeEnum.EMBEDDING combinedText = " ".join(texts[:3])[:500] availableModels = modelRegistry.getAvailableModels() allowedProviders = getattr(options, 'allowedProviders', None) if options else None if allowedProviders: filtered = [m for m in availableModels if m.connectorType in allowedProviders] if filtered: availableModels = filtered else: logger.warning(f"No embedding models match allowedProviders {allowedProviders}") failoverModelList = modelSelector.getFailoverModelList( combinedText, "", options, availableModels ) if not failoverModelList: return AiCallResponse( content="", modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1 ) lastError = None for attempt, model in enumerate(failoverModelList): try: logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})") inputBytes = sum(len(t.encode("utf-8")) for t in texts) startTime = time.time() batches = _buildEmbeddingBatches(texts, model.contextLength) logger.info( f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), " f"model contextLength={model.contextLength}" ) allEmbeddings: List[List[float]] = [] totalPriceCHF = 0.0 for batchIdx, batch in enumerate(batches): modelCall = AiModelCall( model=model, options=options, embeddingInput=batch ) modelResponse = await model.functionCall(modelCall) if not modelResponse.success: raise ValueError(f"Embedding batch {batchIdx + 1} failed: {modelResponse.error}") batchEmbeddings = (modelResponse.metadata or {}).get("embeddings", []) allEmbeddings.extend(batchEmbeddings) batchBytes = sum(len(t.encode("utf-8")) for t in batch) totalPriceCHF += model.calculatepriceCHF(0, batchBytes, 0) processingTime = time.time() - startTime if totalPriceCHF == 0.0: totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0) response = AiCallResponse( content="", modelName=model.name, provider=model.connectorType, priceCHF=totalPriceCHF, processingTime=processingTime, bytesSent=inputBytes, bytesReceived=0, errorCount=0, metadata={"embeddings": allEmbeddings} ) if self.billingCallback: try: self.billingCallback(response) except Exception as e: logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}") return response except _CtxExc as e: logger.error(f"ContextLengthExceeded for {model.name} despite batching – aborting failover: {e}") return AiCallResponse( content=str(e), modelName=model.name, priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1 ) except RateLimitExceededException as rle: retryAfter = rle.retryAfterSeconds lastError = rle cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0 logger.warning(f"Rate limit on {model.name} during embedding (retryAfter={retryAfter:.1f}s)") modelSelector.reportFailure(model.name, cooldownSeconds=cooldown) if attempt < len(failoverModelList) - 1: continue break except Exception as e: lastError = e logger.warning(f"Embedding call failed with {model.name}: {str(e)}") modelSelector.reportFailure(model.name) if attempt < len(failoverModelList) - 1: continue break errorMsg = f"All embedding models failed. Last error: {str(lastError)}" logger.error(errorMsg) return AiCallResponse( content=errorMsg, modelName="error", priceCHF=0.0, processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1 ) # Utility methods async def listAvailableModels(self, connectorType: str = None) -> List[Dict[str, Any]]: """List available models, optionally filtered by connector type.""" models = modelRegistry.getAvailableModels() if connectorType: return [model.model_dump() for model in models if model.connectorType == connectorType] return [model.model_dump() for model in models] async def getModelInfo(self, displayName: str) -> Dict[str, Any]: """Get information about a specific model by displayName.""" model = modelRegistry.getModel(displayName) if not model: raise ValueError(f"Model with displayName '{displayName}' not found") return model.model_dump() async def getModelsByTag(self, tag: str) -> List[str]: """Get model displayNames that have a specific tag. Returns displayNames (unique identifiers).""" models = modelRegistry.getModelsByTag(tag) return [model.displayName for model in models] # ============================================================================= # Internal helpers # ============================================================================= _CHARS_PER_TOKEN = 4 _SAFETY_MARGIN = 0.90 def _estimateTokens(text: str) -> int: """Rough token estimate: 1 token ~ 4 characters.""" return max(1, len(text) // _CHARS_PER_TOKEN) def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[str]]: """Split a list of texts into batches whose total estimated token count stays within the model's contextLength (with safety margin). Each individual text is assumed to already be within limits (enforced by the chunking layer). If a single text exceeds the budget, it is placed in its own batch as a last resort. """ if not texts: return [] if contextLength <= 0: return [texts] maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN) batches: List[List[str]] = [] currentBatch: List[str] = [] currentTokens = 0 for text in texts: textTokens = _estimateTokens(text) if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch: batches.append(currentBatch) currentBatch = [] currentTokens = 0 currentBatch.append(text) currentTokens += textTokens if currentBatch: batches.append(currentBatch) return batches