# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import logging
import asyncio
import uuid
import base64
from typing import Dict, Any, List, Union, Tuple, Optional, Callable, AsyncGenerator
from dataclasses import dataclass, field
import time

logger = logging.getLogger(__name__)

from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.aicore.aicoreBase import RateLimitExceededException
from modules.datamodels.datamodelAi import (
    AiModel,
    AiCallOptions,
    AiCallRequest,
    AiCallResponse,
    OperationTypeEnum,
    AiModelCall,
    AiModelResponse,
)
from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy


# Dynamic model registry - models are loaded from connectors via the aicore system


@dataclass(slots=True)
class AiObjects:
    """Centralized AI interface: dynamically discovers and uses AI models.

    billingCallback: Set by serviceAi before AI calls. Called after EVERY individual
    model call with the AiCallResponse. This ensures per-model-call billing with
    the exact provider + model name. The callback handles billing recording.
    """
    billingCallback: Optional[Callable] = field(default=None, repr=False)

    def __post_init__(self) -> None:
        # Auto-discover and register all available connectors
        self._discoverAndRegisterConnectors()

    def _discoverAndRegisterConnectors(self) -> None:
        """Auto-discover and register all available AI connectors."""
        logger.info("Auto-discovering AI connectors...")

        # Use the model registry's built-in discovery mechanism
        discoveredConnectors = modelRegistry.discoverConnectors()

        # Register each discovered connector
        for connector in discoveredConnectors:
            modelRegistry.registerConnector(connector)
            logger.info(f"Registered connector: {connector.getConnectorType()}")

        logger.info(f"Total connectors registered: {len(discoveredConnectors)}")
        logger.info("All AI connectors registered with the dynamic model registry")

    @classmethod
    async def create(cls) -> "AiObjects":
        """Create an AiObjects instance with auto-discovered connectors."""
        # No need to manually create connectors - they are auto-discovered
        return cls()
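
    # A minimal wiring sketch (hypothetical; the actual serviceAi setup may
    # differ): create the interface, then attach a billing callback that
    # receives one AiCallResponse per individual model call.
    #
    #   aiObjects = await AiObjects.create()
    #   aiObjects.billingCallback = lambda resp: print(
    #       f"bill {resp.provider}/{resp.modelName}: {resp.priceCHF:.6f} CHF"
    #   )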

    def _selectModel(self, prompt: str, context: str, options: AiCallOptions) -> str:
        """Select the best model via the dynamic model selection system. Returns the displayName (unique identifier)."""
        # Get available models from the dynamic registry
        availableModels = modelRegistry.getAvailableModels()

        if not availableModels:
            logger.error("No models available in the registry")
            raise ValueError("No AI models available")

        # Use the dynamic model selector
        selectedModel = modelSelector.selectModel(prompt, context, options, availableModels)

        if not selectedModel:
            logger.error("No suitable model found for the given criteria")
            raise ValueError("No suitable AI model found")

        logger.info(f"Selected model: {selectedModel.name} ({selectedModel.displayName})")
        return selectedModel.displayName

    # AI for Extraction, Processing, Generation
    async def callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
        """Call an AI model for traditional text/context calls with a fallback mechanism.

        Supports two modes:
        - Legacy: prompt + context → constructs messages internally
        - Agent: request.messages provided → passed through directly
        """
        prompt = request.prompt
        context = request.context or ""
        options = request.options

        # Get failover models for this operation type
        availableModels = modelRegistry.getAvailableModels()

        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filteredModels = [m for m in availableModels if m.connectorType in allowedProviders]
            if filteredModels:
                availableModels = filteredModels
            else:
                errorMsg = f"No models match allowedProviders {allowedProviders} for operation {options.operationType}"
                logger.error(errorMsg)
                return self._createErrorResponse(errorMsg, 0, 0)

        failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)

        if not failoverModelList:
            errorMsg = f"No suitable models found for operation {options.operationType}"
            logger.error(errorMsg)
            return self._createErrorResponse(errorMsg, 0, 0)

        _MAX_SHORT_RETRY = 15.0

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")

                if request.messages:
                    response = await self._callWithMessages(model, request.messages, options, request.tools, toolChoice=request.toolChoice)
                else:
                    response = await self._callWithModel(model, prompt, context, options)

                logger.info(f"AI call successful with model: {model.name}")
                return response

            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                if 0 < retryAfter <= _MAX_SHORT_RETRY:
                    logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
                    await asyncio.sleep(retryAfter + 0.5)
                    try:
                        if request.messages:
                            response = await self._callWithMessages(model, request.messages, options, request.tools, toolChoice=request.toolChoice)
                        else:
                            response = await self._callWithModel(model, prompt, context, options)
                        logger.info(f"AI call successful with {model.name} after rate-limit retry")
                        return response
                    except Exception as retryErr:
                        lastError = retryErr
                        logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
                else:
                    logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
                cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                if attempt < len(failoverModelList) - 1:
                    continue
                logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                break

            except Exception as e:
                lastError = e
                logger.warning(f"AI call failed with model {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)

                if attempt < len(failoverModelList) - 1:
                    logger.info("Trying next failover model...")
                    continue
                logger.error(f"All {len(failoverModelList)} models failed for operation {options.operationType}")
                break

        # All failover attempts failed - return error response
        errorMsg = f"All AI models failed for operation {options.operationType}. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return self._createErrorResponse(errorMsg, 0, 0)
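
    # Usage sketch for the legacy mode. Field names on AiCallRequest and
    # AiCallOptions follow the datamodel imports above; OperationTypeEnum.PROCESSING
    # is an assumed member used here only for illustration.
    #
    #   request = AiCallRequest(
    #       prompt="Summarize the findings.",
    #       context="...document text...",
    #       options=AiCallOptions(operationType=OperationTypeEnum.PROCESSING),
    #   )
    #   response = await aiObjects.callWithTextContext(request)
    #   if response.errorCount == 0:
    #       print(response.content)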

    def _createErrorResponse(self, errorMsg: str, inputBytes: int, outputBytes: int) -> AiCallResponse:
        """Create an error response."""
        return AiCallResponse(
            content=errorMsg,
            modelName="error",
            priceCHF=0.0,
            processingTime=0.0,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=1,
        )

    async def _callWithModel(self, model: AiModel, prompt: str, context: str, options: Optional[AiCallOptions] = None) -> AiCallResponse:
        """Call a specific model and return the response."""
        # Calculate input bytes from prompt and context
        inputBytes = len((prompt + context).encode('utf-8'))

        # Replace the <TOKEN_LIMIT> placeholder with the model's maxTokens value
        if "<TOKEN_LIMIT>" in prompt:
            if model.maxTokens > 0:
                tokenLimit = str(model.maxTokens)
                modelPrompt = prompt.replace("<TOKEN_LIMIT>", tokenLimit)
                logger.debug(f"Replaced <TOKEN_LIMIT> with {tokenLimit} for model {model.name}")
            else:
                raise ValueError(f"Model {model.name} has invalid maxTokens ({model.maxTokens}). Cannot set token limit.")
        else:
            modelPrompt = prompt

        # Build the messages array with the replaced content
        messages = []
        if context:
            messages.append({"role": "system", "content": f"Context from documents:\n{context}"})
        messages.append({"role": "user", "content": modelPrompt})

        # Start timing
        startTime = time.time()

        # Call the model's function directly - completely generic
        if model.functionCall:
            # Create a standardized call object
            modelCall = AiModelCall(
                messages=messages,
                model=model,
                options=options or {}
            )

            # Log before calling the model
            contextSize = len(context.encode('utf-8')) if context else 0
            promptSize = len(modelPrompt.encode('utf-8')) if modelPrompt else 0
            totalInputSize = contextSize + promptSize
            logger.debug(f"Calling model {model.name} with {len(messages)} messages, context size: {contextSize} bytes, prompt size: {promptSize} bytes, total input: {totalInputSize} bytes")

            # Call the model with the standardized interface
            modelResponse = await model.functionCall(modelCall)

            # Log after a successful call
            logger.debug(f"Model {model.name} returned successfully")

            # Extract content from the standardized response
            if not modelResponse.success:
                raise ValueError(f"Model call failed: {modelResponse.error}")
            content = modelResponse.content
        else:
            raise ValueError(f"Model {model.name} has no function call defined")

        # Calculate timing and output bytes
        endTime = time.time()
        processingTime = endTime - startTime
        outputBytes = len(content.encode("utf-8"))

        # Calculate price using the model's own price calculation method
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
        )

        # BILLING: Record billing for THIS specific model call.
        # billingCallback is set by serviceAi and records one billing transaction
        # per model call with the exact provider + model name.
        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")

        return response
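
    # Illustration of the <TOKEN_LIMIT> replacement above: for a model with
    # maxTokens=4096, the prompt "Answer in at most <TOKEN_LIMIT> tokens."
    # becomes "Answer in at most 4096 tokens." before the messages are built.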

    async def _callWithMessages(self, model: AiModel, messages: List[Dict[str, Any]],
                                options: Optional[AiCallOptions] = None,
                                tools: Optional[List[Dict[str, Any]]] = None,
                                toolChoice: Any = None) -> AiCallResponse:
        """Call a model with pre-built messages (agent mode). Supports tools for native function calling."""
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCall:
            raise ValueError(f"Model {model.name} has no function call defined")

        modelCall = AiModelCall(
            messages=messages,
            model=model,
            options=options or {},
            tools=tools,
            toolChoice=toolChoice,
        )

        modelResponse = await model.functionCall(modelCall)

        if not modelResponse.success:
            raise ValueError(f"Model call failed: {modelResponse.error}")

        endTime = time.time()
        processingTime = endTime - startTime
        content = modelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        # Extract tool calls from metadata if present (native function calling)
        responseToolCalls = None
        if modelResponse.metadata:
            responseToolCalls = modelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )
        response._modelMaxTokens = model.maxTokens

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")

        return response
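
    # Agent-mode sketch with native function calling. The tool schema below is
    # the common OpenAI-style layout; whether every connector accepts exactly
    # this shape (and toolChoice="auto") is an assumption, and getWeather is a
    # hypothetical tool.
    #
    #   tools = [{
    #       "type": "function",
    #       "function": {
    #           "name": "getWeather",
    #           "parameters": {"type": "object",
    #                          "properties": {"city": {"type": "string"}}},
    #       },
    #   }]
    #   request = AiCallRequest(prompt="", options=options, tools=tools,
    #                           toolChoice="auto",
    #                           messages=[{"role": "user", "content": "Weather in Bern?"}])
    #   response = await aiObjects.callWithTextContext(request)
    #   for call in (response.toolCalls or []):
    #       ...  # dispatch each tool call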

    async def callWithTextContextStream(
        self, request: AiCallRequest
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Streaming variant of callWithTextContext. Yields str deltas, then the final AiCallResponse."""
        options = request.options
        availableModels = modelRegistry.getAvailableModels()

        allowedProviders = getattr(options, 'allowedProviders', None) if options else None
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered
            else:
                yield self._createErrorResponse(
                    f"No models match allowedProviders {allowedProviders} for operation {options.operationType}", 0, 0
                )
                return

        failoverModelList = modelSelector.getFailoverModelList(
            request.prompt, request.context or "", options, availableModels
        )
        if not failoverModelList:
            yield self._createErrorResponse(
                f"No suitable models found for operation {options.operationType}", 0, 0
            )
            return

        # Build messages from prompt/context when the caller did not supply them,
        # mirroring the legacy mode of callWithTextContext (passing None through
        # would crash in _callWithMessagesStream).
        messages = request.messages
        if not messages:
            messages = []
            if request.context:
                messages.append({"role": "system", "content": f"Context from documents:\n{request.context}"})
            messages.append({"role": "user", "content": request.prompt})

        _MAX_SHORT_RETRY = 15.0

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Streaming AI call with model: {model.name} (attempt {attempt + 1})")
                async for chunk in self._callWithMessagesStream(model, messages, options, request.tools, toolChoice=request.toolChoice):
                    yield chunk
                return

            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                if 0 < retryAfter <= _MAX_SHORT_RETRY:
                    logger.info(f"Rate limit on {model.name}, waiting {retryAfter:.1f}s before retry")
                    await asyncio.sleep(retryAfter + 0.5)
                    try:
                        async for chunk in self._callWithMessagesStream(model, messages, options, request.tools, toolChoice=request.toolChoice):
                            yield chunk
                        return
                    except Exception as retryErr:
                        lastError = retryErr
                        logger.warning(f"Retry after rate-limit wait also failed for {model.name}: {retryErr}")
                else:
                    logger.warning(f"Rate limit on {model.name} (retryAfter={retryAfter:.1f}s), failing over")
                cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

            except Exception as e:
                lastError = e
                logger.warning(f"Streaming AI call failed with {model.name}: {e}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        yield self._createErrorResponse(
            f"All models failed (stream). Last error: {lastError}", 0, 0
        )
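
    # Consumption sketch: str deltas arrive first, then exactly one final
    # AiCallResponse (which also carries the error when all models fail).
    #
    #   async for item in aiObjects.callWithTextContextStream(request):
    #       if isinstance(item, AiCallResponse):
    #           final = item  # price, toolCalls, and error info live here
    #       else:
    #           print(item, end="", flush=True)  # incremental text delta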

    async def _callWithMessagesStream(
        self, model: AiModel, messages: List[Dict[str, Any]],
        options: Optional[AiCallOptions] = None, tools: Optional[List[Dict[str, Any]]] = None,
        toolChoice: Any = None,
    ) -> AsyncGenerator[Union[str, AiCallResponse], None]:
        """Stream a model call. Yields str deltas, then the final AiCallResponse with billing."""
        inputBytes = sum(len(str(m.get("content", "")).encode("utf-8")) for m in messages)
        startTime = time.time()

        if not model.functionCallStream:
            response = await self._callWithMessages(model, messages, options, tools, toolChoice=toolChoice)
            if response.content:
                yield response.content
            yield response
            return

        modelCall = AiModelCall(
            messages=messages, model=model,
            options=options or {}, tools=tools,
            toolChoice=toolChoice,
        )

        finalModelResponse = None
        async for item in model.functionCallStream(modelCall):
            if isinstance(item, AiModelResponse):
                finalModelResponse = item
            else:
                yield item

        if not finalModelResponse:
            raise ValueError(f"Stream from {model.name} produced no final AiModelResponse")

        endTime = time.time()
        processingTime = endTime - startTime
        content = finalModelResponse.content
        outputBytes = len(content.encode("utf-8"))
        priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)

        responseToolCalls = None
        if finalModelResponse.metadata:
            responseToolCalls = finalModelResponse.metadata.get("toolCalls")

        response = AiCallResponse(
            content=content,
            modelName=model.name,
            provider=model.connectorType,
            priceCHF=priceCHF,
            processingTime=processingTime,
            bytesSent=inputBytes,
            bytesReceived=outputBytes,
            errorCount=0,
            toolCalls=responseToolCalls,
        )
        response._modelMaxTokens = model.maxTokens

        if self.billingCallback:
            try:
                self.billingCallback(response)
            except Exception as e:
                logger.error(f"BILLING: Failed to record stream billing for {model.name}: {e}")

        yield response

    async def callEmbedding(self, texts: List[str], options: Optional[AiCallOptions] = None) -> AiCallResponse:
        """Generate embeddings for a list of texts using the best available embedding model.

        Token-aware batching: splits the texts list into batches that respect the
        model's contextLength (with a 10% safety margin). Each batch is sent as a
        separate API call; the resulting embeddings are merged in order.

        Failover across providers (OpenAI -> Mistral) works identically to chat models,
        but ContextLengthExceededException is NOT retried via failover (all providers
        share the same limits).

        Returns:
            AiCallResponse with metadata["embeddings"] containing the vectors.
        """
        from modules.aicore.aicoreBase import ContextLengthExceededException as _CtxExc

        if options is None:
            options = AiCallOptions(operationType=OperationTypeEnum.EMBEDDING)
        else:
            options.operationType = OperationTypeEnum.EMBEDDING

        # Short sample of the input, used only for model selection
        combinedText = " ".join(texts[:3])[:500]
        availableModels = modelRegistry.getAvailableModels()

        allowedProviders = getattr(options, 'allowedProviders', None)
        if allowedProviders:
            filtered = [m for m in availableModels if m.connectorType in allowedProviders]
            if filtered:
                availableModels = filtered
            else:
                logger.warning(f"No embedding models match allowedProviders {allowedProviders}")

        failoverModelList = modelSelector.getFailoverModelList(
            combinedText, "", options, availableModels
        )

        if not failoverModelList:
            return self._createErrorResponse("No suitable embedding models found", 0, 0)

        lastError = None
        for attempt, model in enumerate(failoverModelList):
            try:
                logger.info(f"Embedding call with {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
                inputBytes = sum(len(t.encode("utf-8")) for t in texts)
                startTime = time.time()

                batches = _buildEmbeddingBatches(texts, model.contextLength)
                logger.info(
                    f"Embedding: {len(texts)} texts -> {len(batches)} batch(es), "
                    f"model contextLength={model.contextLength}"
                )

                allEmbeddings: List[List[float]] = []
                totalPriceCHF = 0.0

                for batchIdx, batch in enumerate(batches):
                    modelCall = AiModelCall(
                        model=model, options=options, embeddingInput=batch
                    )
                    modelResponse = await model.functionCall(modelCall)

                    if not modelResponse.success:
                        raise ValueError(f"Embedding batch {batchIdx + 1} failed: {modelResponse.error}")

                    batchEmbeddings = (modelResponse.metadata or {}).get("embeddings", [])
                    allEmbeddings.extend(batchEmbeddings)

                    batchBytes = sum(len(t.encode("utf-8")) for t in batch)
                    totalPriceCHF += model.calculatepriceCHF(0, batchBytes, 0)

                processingTime = time.time() - startTime
                if totalPriceCHF == 0.0:
                    totalPriceCHF = model.calculatepriceCHF(processingTime, inputBytes, 0)

                response = AiCallResponse(
                    content="", modelName=model.name, provider=model.connectorType,
                    priceCHF=totalPriceCHF, processingTime=processingTime,
                    bytesSent=inputBytes, bytesReceived=0, errorCount=0,
                    metadata={"embeddings": allEmbeddings}
                )

                if self.billingCallback:
                    try:
                        self.billingCallback(response)
                    except Exception as e:
                        logger.error(f"BILLING: Failed to record billing for embedding {model.name}: {e}")

                return response

            except _CtxExc as e:
                logger.error(f"ContextLengthExceeded for {model.name} despite batching; aborting failover: {e}")
                return AiCallResponse(
                    content=str(e), modelName=model.name, priceCHF=0.0,
                    processingTime=0.0, bytesSent=0, bytesReceived=0, errorCount=1
                )

            except RateLimitExceededException as rle:
                retryAfter = rle.retryAfterSeconds
                lastError = rle
                cooldown = max(retryAfter, 10.0) if retryAfter > 0 else 0.0
                logger.warning(f"Rate limit on {model.name} during embedding (retryAfter={retryAfter:.1f}s)")
                modelSelector.reportFailure(model.name, cooldownSeconds=cooldown)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

            except Exception as e:
                lastError = e
                logger.warning(f"Embedding call failed with {model.name}: {str(e)}")
                modelSelector.reportFailure(model.name)
                if attempt < len(failoverModelList) - 1:
                    continue
                break

        errorMsg = f"All embedding models failed. Last error: {str(lastError)}"
        logger.error(errorMsg)
        return self._createErrorResponse(errorMsg, 0, 0)
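
    # Embedding usage sketch (the vector dimensionality depends on whichever
    # provider model is selected, so none is assumed here):
    #
    #   resp = await aiObjects.callEmbedding(["first chunk", "second chunk"])
    #   vectors = resp.metadata["embeddings"] if resp.errorCount == 0 else []
    #   assert len(vectors) == 2  # one vector per input text, merged in order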

    # Utility methods
    async def listAvailableModels(self, connectorType: Optional[str] = None) -> List[Dict[str, Any]]:
        """List available models, optionally filtered by connector type."""
        models = modelRegistry.getAvailableModels()
        if connectorType:
            return [model.model_dump() for model in models if model.connectorType == connectorType]
        return [model.model_dump() for model in models]

    async def getModelInfo(self, displayName: str) -> Dict[str, Any]:
        """Get information about a specific model by displayName."""
        model = modelRegistry.getModel(displayName)
        if not model:
            raise ValueError(f"Model with displayName '{displayName}' not found")
        return model.model_dump()

    async def getModelsByTag(self, tag: str) -> List[str]:
        """Get models that carry a specific tag. Returns their displayNames (unique identifiers)."""
        models = modelRegistry.getModelsByTag(tag)
        return [model.displayName for model in models]


# =============================================================================
# Internal helpers
# =============================================================================

_CHARS_PER_TOKEN = 4
_SAFETY_MARGIN = 0.90


def _estimateTokens(text: str) -> int:
    """Rough token estimate: 1 token ~ 4 characters."""
    return max(1, len(text) // _CHARS_PER_TOKEN)
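
# Worked example: len("hello world") == 11 and 11 // 4 == 2, so
# _estimateTokens("hello world") == 2; the max(1, ...) floor only matters
# for strings shorter than _CHARS_PER_TOKEN characters.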


def _buildEmbeddingBatches(texts: List[str], contextLength: int) -> List[List[str]]:
    """Split a list of texts into batches whose total estimated token count
    stays within the model's contextLength (with a safety margin).

    Each individual text is assumed to already be within limits (enforced by
    the chunking layer). If a single text exceeds the budget, it is placed
    in its own batch as a last resort.
    """
    if not texts:
        return []
    if contextLength <= 0:
        return [texts]

    maxTokensPerBatch = int(contextLength * _SAFETY_MARGIN)
    batches: List[List[str]] = []
    currentBatch: List[str] = []
    currentTokens = 0

    for text in texts:
        textTokens = _estimateTokens(text)
        if currentBatch and (currentTokens + textTokens) > maxTokensPerBatch:
            batches.append(currentBatch)
            currentBatch = []
            currentTokens = 0
        currentBatch.append(text)
        currentTokens += textTokens

    if currentBatch:
        batches.append(currentBatch)

    return batches
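
# Worked example: with contextLength=8192 the per-batch budget is
# int(8192 * 0.90) == 7372 estimated tokens. Forty texts of ~1000 estimated
# tokens each fit 7 to a batch (7000 <= 7372, while 8000 would not), giving
# five full batches of 7 and a final batch of 5.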