gateway/modules/interfaces/interfaceAiObjects.py
2026-03-15 23:38:21 +01:00

517 lines
21 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import asyncio
import uuid
import base64
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
from dataclasses import dataclass, field
import time
logger = logging.getLogger(__name__)
from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.datamodels.datamodelAi import (
AiModel,
AiCallOptions,
AiCallRequest,
AiCallResponse,
OperationTypeEnum,
AiModelCall,
AiModelResponse,
)
from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy
# Dynamic model registry - models are now loaded from connectors via aicore system
@dataclass(slots=True)
class AiObjects:
    """Centralized AI interface: dynamically discovers and uses AI models.
    billingCallback: Set by serviceAi before AI calls. Called after EVERY individual
    model call with the AiCallResponse. This ensures per-model-call billing with
    exact provider + model name. The callback handles billing recording.
    """
    # Per-call billing hook; invoked once per successful model call (see docstring).
    billingCallback: Optional[Callable] = field(default=None, repr=False)

    def __post_init__(self) -> None:
        # Auto-discover and register all available connectors at construction time
        self._discoverAndRegisterConnectors()

    def _discoverAndRegisterConnectors(self) -> None:
        """Auto-discover and register all available AI connectors."""
        logger.info("Auto-discovering AI connectors...")
        # Use the model registry's built-in discovery mechanism
        discoveredConnectors = modelRegistry.discoverConnectors()
        # Register each discovered connector
        for connector in discoveredConnectors:
            modelRegistry.registerConnector(connector)
            logger.info(f"Registered connector: {connector.getConnectorType()}")
        logger.info(f"Total connectors registered: {len(discoveredConnectors)}")
        logger.info("All AI connectors registered with dynamic model registry")

    @classmethod
    async def create(cls) -> "AiObjects":
        """Create AiObjects instance with auto-discovered connectors."""
        # No need to manually create connectors - they're auto-discovered in __post_init__
        return cls()

    def _selectModel(self, prompt: str, context: str, options: AiCallOptions) -> str:
        """Select the best model using dynamic model selection system. Returns displayName (unique identifier).

        Raises:
            ValueError: if no models are registered or none match the criteria.
        """
        # Get available models from the dynamic registry
        availableModels = modelRegistry.getAvailableModels()
        if not availableModels:
            logger.error("No models available in the registry")
            raise ValueError("No AI models available")
        # Use the dynamic model selector
        selectedModel = modelSelector.selectModel(prompt, context, options, availableModels)
        if not selectedModel:
            logger.error("No suitable model found for the given criteria")
            raise ValueError("No suitable AI model found")
        logger.info(f"Selected model: {selectedModel.name} ({selectedModel.displayName})")
        return selectedModel.displayName

    def _filterByAllowedProviders(self, availableModels: List[AiModel],
                                  options: AiCallOptions) -> List[AiModel]:
        """Filter models by options.allowedProviders (from workflow config).

        Falls back to the unfiltered list when no model matches, so a misconfigured
        allow-list degrades gracefully instead of failing every call.
        """
        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if not allowedProviders:
            return availableModels
        filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
        if filteredModels:
            logger.info(f"Filtered models by allowedProviders {allowedProviders}: {len(filteredModels)} models (from {len(availableModels)})")
            return filteredModels
        logger.warning(f"No models match allowedProviders {allowedProviders}, using all {len(availableModels)} available models")
        return availableModels

    def _buildPromptMessages(self, model: AiModel, prompt: str, context: str) -> List[Dict[str, Any]]:
        """Build the legacy (prompt + context) message list for one specific model.

        Replaces the <TOKEN_LIMIT> placeholder with the model's maxTokens value,
        which is why message construction is per-model and happens inside the
        failover loop rather than once up front.

        Raises:
            ValueError: if prompt contains <TOKEN_LIMIT> but model.maxTokens is not positive.
        """
        # Replace <TOKEN_LIMIT> placeholder with model's maxTokens value
        if "<TOKEN_LIMIT>" in prompt:
            if model.maxTokens > 0:
                tokenLimit = str(model.maxTokens)
                modelPrompt = prompt.replace("<TOKEN_LIMIT>", tokenLimit)
                logger.debug(f"Replaced <TOKEN_LIMIT> with {tokenLimit} for model {model.name}")
            else:
                raise ValueError(f"Model {model.name} has invalid maxTokens ({model.maxTokens}). Cannot set token limit.")
        else:
            modelPrompt = prompt
        # Context (when present) rides in a system message ahead of the user prompt
        messages = []
        if context:
            messages.append({"role": "system", "content": f"Context from documents:\n{context}"})
        messages.append({"role": "user", "content": modelPrompt})
        return messages

    # AI for Extraction, Processing, Generation
    async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
        """Call AI model for traditional text/context calls with fallback mechanism.
        Supports two modes:
        - Legacy: prompt + context → constructs messages internally
        - Agent: request.messages provided → passes through directly

        Never raises on model failure: after exhausting the failover list an
        error-shaped AiCallResponse (errorCount=1) is returned instead.
        """
        prompt = request.prompt
        context = request.context or ""
        options = request.options
        # Get failover models for this operation type, honoring the provider allow-list
        availableModels = self._filterByAllowedProviders(
            modelRegistry.getAvailableModels(), options
        )
        failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
        if not failoverModelList:
            errorMsg = f"No suitable models found for operation {options.operationType}"
            logger.error(errorMsg)
            return self._createErrorResponse(errorMsg, 0, 0)
        # Try each model in failover sequence
        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                if request.messages:
                    response = await self._callWithMessages(model, request.messages, options, request.tools)
                else:
                    response = await self._callWithModel(model, prompt, context, options)
                logger.info(f"AI call successful with model: {model.name}")
                return response
            except Exception as e:
                lastError = e
                logger.warning(f"AI call failed with model {model.name}: {str(e)}")
                # Let the selector de-prioritize this model for future calls
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    logger.info("Trying next failover model...")
                    continue
                logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                break
        # All failover attempts failed - return error response
        errorMsg = f"All AI models failed for operation {options.operationType}. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return self._createErrorResponse(errorMsg, 0, 0)

    def _createErrorResponse(self, errorMsg: str, inputBytes: int, outputBytes: int) -> AiCallResponse:
        """Create an error response (errorCount=1, zero price, modelName='error')."""
        return AiCallResponse(
            content=errorMsg,
            modelName="error",
            priceCHF=0.0,
            processingTime=0.0,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=1
        )

    async def _callWithModel(self, model: AiModel, prompt: str, context: str,
                             options: AiCallOptions = None) -> AiCallResponse:
        """Call a specific model (legacy prompt/context mode) and return the response.

        Raises:
            ValueError: if the model defines no functionCall, the call reports
                failure, or <TOKEN_LIMIT> cannot be resolved.
        """
        # Calculate input bytes from prompt and context as supplied by the caller
        # (before <TOKEN_LIMIT> substitution), for billing/accounting
        inputBytes = len((prompt + context).encode('utf-8'))
        # Build the per-model message list (resolves <TOKEN_LIMIT>)
        messages = self._buildPromptMessages(model, prompt, context)
        # Start timing
        startTime = time.time()
        # Call the model's function directly - completely generic
        if model.functionCall:
            # Create standardized call object
            modelCall = AiModelCall(
                messages=messages,
                model=model,
                options=options or {}
            )
            # Log sizes before calling model
            contextSize = len(context.encode('utf-8')) if context else 0
            promptSize = len(messages[-1]["content"].encode('utf-8'))
            totalInputSize = contextSize + promptSize
            logger.debug(f"Calling model {model.name} with {len(messages)} messages, context size: {contextSize} bytes, prompt size: {promptSize} bytes, total input: {totalInputSize} bytes")
            # Call the model with standardized interface
            modelResponse = await model.functionCall(modelCall)
            logger.debug(f"Model {model.name} returned successfully")
            # Extract content from standardized response
            if not modelResponse.success:
                raise ValueError(f"Model call failed: {modelResponse.error}")
            content = modelResponse.content
        else:
            raise ValueError(f"Model {model.name} has no function call defined")
        # Calculate timing and output bytes
        endTime = time.time()
        processingTime = endTime - startTime
        outputBytes = len(content.encode("utf-8"))
        # Calculate price using model's own price calculation method
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0
        )
        # BILLING: Record billing for THIS specific model call
        # billingCallback is set by serviceAi and records one billing transaction
        # per model call with exact provider + model name
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
        return response

    async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]],
                                options: AiCallOptions = None,
                                tools: List[Dict[str, Any]] = None) -> AiCallResponse:
        """Call a model with pre-built messages (agent mode). Supports tools for native function calling.

        Raises:
            ValueError: if the model defines no functionCall or the call reports failure.
        """
        # Input accounting: byte length of every message's content field
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()
        if not model.functionCall:
            raise ValueError(f"Model {model.name} has no function call defined")
        modelCall = AiModelCall(
            messages=messages,
            model=model,
            options=options or {},
            tools=tools
        )
        modelResponse = await model.functionCall(modelCall)
        if not modelResponse.success:
            raise ValueError(f"Model call failed: {modelResponse.error}")
        endTime = time.time()
        processingTime = endTime - startTime
        content = modelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
        # Extract tool calls from metadata if present (native function calling)
        responseToolCalls = None
        if modelResponse.metadata:
            responseToolCalls = modelResponse.metadata.get("toolCalls")
        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls
        )
        # BILLING: one transaction per successful model call
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
        return response

    async def callWithTextContextStream(
        self, request: AiCallRequest
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Streaming variant of callWithTextContext. Yields str deltas, then final AiCallResponse.

        Supports both agent mode (request.messages) and legacy mode
        (prompt + context → messages built per model), mirroring callWithTextContext.
        """
        options = request.options
        availableModels = self._filterByAllowedProviders(
            modelRegistry.getAvailableModels(), options
        )
        failoverModelList = modelSelector.getFailoverModelList(
            request.prompt, request.context or "", options, availableModels
        )
        if not failoverModelList:
            yield self._createErrorResponse(
                f"No suitable models found for operation {options.operationType}", 0, 0
            )
            return
        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})")
                # BUGFIX: in legacy mode request.messages is None; previously it was
                # passed through unchanged and the stream helper crashed iterating it.
                # Build the message list per model exactly like the non-stream path.
                messages = request.messages or self._buildPromptMessages(
                    model, request.prompt, request.context or ""
                )
                async for chunk in self._callWithMessagesStream(model, messages, options, request.tools):
                    yield chunk
                return
            except Exception as e:
                lastError = e
                logger.warning(f"Streaming AI call failed with {model.name}: {e}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break
        yield self._createErrorResponse(
            f"All models failed (stream). Last error: {lastError}", 0, 0
        )

    async def _callWithMessagesStream(
        self, model: AiModel, messages: List[Dict[str, Any]],
        options: AiCallOptions = None, tools: List[Dict[str, Any]] = None,
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Stream a model call. Yields str deltas, then final AiCallResponse with billing.

        Raises:
            ValueError: if the stream ends without a final AiModelResponse.
        """
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()
        # Models without native streaming fall back to the blocking call and
        # emit the full content as one delta followed by the final response.
        if not model.functionCallStream:
            response = await self._callWithMessages(model, messages, options, tools)
            if response.content:
                yield response.content
            yield response
            return
        modelCall = AiModelCall(
            messages=messages, model=model,
            options=options or {}, tools=tools,
        )
        finalModelResponse = None
        # Connectors yield str deltas and terminate with one AiModelResponse
        async for item in model.functionCallStream(modelCall):
            if isinstance(item, AiModelResponse):
                finalModelResponse = item
            else:
                yield item
        if not finalModelResponse:
            raise ValueError(f"Stream from {model.name} produced no final AiModelResponse")
        endTime = time.time()
        processingTime = endTime - startTime
        content = finalModelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
        # Extract tool calls from metadata if present (native function calling)
        responseToolCalls = None
        if finalModelResponse.metadata:
            responseToolCalls = finalModelResponse.metadata.get("toolCalls")
        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )
        # BILLING: one transaction per successful streamed model call
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}")
        yield response

    async def callEmbedding(self, texts: List[str], options: AiCallOptions = None) -> AiCallResponse:
        """Generate embeddings for a list of texts using the best available embedding model.
        Uses the standard model selector with OperationTypeEnum.EMBEDDING to pick the model.
        Failover across providers (OpenAI → Mistral) works identically to chat models.
        Returns:
            AiCallResponse with metadata["embeddings"] containing the vectors.
        """
        if options is None:
            options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING)
        else:
            # NOTE(review): this mutates the caller's options object to force the
            # embedding operation type — confirm no caller reuses the same options.
            options.operationType = OperationTypeEnum.EMBEDDING
        # Short sample of the input, used only by the selector's heuristics
        combinedText = " ".join(texts[:3])[:500]
        availableModels = modelRegistry.getAvailableModels()
        failoverModelList = modelSelector.getFailoverModelList(
            combinedText, "", options, availableModels
        )
        if not failoverModelList:
            return self._createErrorResponse("", 0, 0)
        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                inputBytes = sum(len(t.encode("utf-8")) for t in texts)
                startTime = time.time()
                modelCall = AiModelCall(
                    model=model, options=options, embeddingInput=texts
                )
                modelResponse = await model.functionCall(modelCall)
                if not modelResponse.success:
                    raise ValueError(f"Embedding call failed: {modelResponse.error}")
                processingTime = time.time() - startTime
                # Embeddings produce no text output, so outputBytes is 0 for pricing
                priceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)
                embeddings = (modelResponse.metadata or {}).get("embeddings", [])
                response = AiCallResponse(
                    content="", modelName=model.name, provider=model.connectorType,
                    priceCHF=priceCHF, processingTime=processingTime,
                    bytesSent=inputBytes, bytesReceived=0, errorCount=0,
                    metadata={"embeddings": embeddings}
                )
                # BILLING: one transaction per successful embedding call
                if self.billingCallback:
                    try:
                        self.billingCallback(response)
                    except Exception as e:
                        logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}")
                return response
            except Exception as e:
                lastError = e
                logger.warning(f"Embedding call failed with {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break
        errorMsg = f"All embedding models failed. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return self._createErrorResponse(errorMsg, 0, 0)

    # Utility methods
    async def listAvailableModels(self, connectorType: str = None) -> List[Dict[str, Any]]:
        """List available models, optionally filtered by connector type."""
        models = modelRegistry.getAvailableModels()
        if connectorType:
            return [model.model_dump() for model in models if model.connectorType == connectorType]
        return [model.model_dump() for model in models]

    async def getModelInfo(self, displayName: str) -> Dict[str, Any]:
        """Get information about a specific model by displayName.

        Raises:
            ValueError: if no model with the given displayName is registered.
        """
        model = modelRegistry.getModel(displayName)
        if not model:
            raise ValueError(f"Model with displayName '{displayName}' not found")
        return model.model_dump()

    async def getModelsByTag(self, tag: str) -> List[str]:
        """Get model displayNames that have a specific tag. Returns displayNames (unique identifiers)."""
        models = modelRegistry.getModelsByTag(tag)
        return [model.displayName for model in models]