# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""
Base connector interface for AI connectors.

All AI connectors should inherit from this class.

IMPORTANT: Model Registration Requirements
- Each model must have a unique displayName across all connectors
- The displayName is used as the unique identifier in the model registry
- The name field is used for API calls (can be duplicated across different model instances)
- If duplicate displayNames are detected during registration, an error will be raised
"""

import re as _re
import time

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, AsyncGenerator, Union
from modules.datamodels.datamodelAi import AiModel, AiModelCall, AiModelResponse


_RETRY_AFTER_PATTERN = _re.compile(
    r"(?:try again in|retry after)\s+(\d+(?:\.\d+)?)\s*s", _re.IGNORECASE
)


def _parseRetryAfterSeconds(message: str) -> float:
    """Extract retry-after seconds from provider error messages like 'Please try again in 6.558s'."""
    match = _RETRY_AFTER_PATTERN.search(message)
    return float(match.group(1)) if match else 0.0
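
# Illustration (hedged, doctest-style; not wired into a test runner):
#     >>> _parseRetryAfterSeconds("Please try again in 6.558s")
#     6.558
#     >>> _parseRetryAfterSeconds("Rate limit hit, retry after 30 s")
#     30.0
#     >>> _parseRetryAfterSeconds("unrecognized message")
#     0.0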


class RateLimitExceededException(Exception):
    """Raised when a provider's rate limit (TPM / RPM) is exceeded."""

    def __init__(self, message: str = "Rate limit exceeded", retryAfterSeconds: float = 0.0):
        super().__init__(message)
        if retryAfterSeconds <= 0:
            retryAfterSeconds = _parseRetryAfterSeconds(message)
        self.retryAfterSeconds = retryAfterSeconds
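
# Hedged usage sketch: how a caller might honor retryAfterSeconds. The names
# `connector` and `call` are hypothetical; only the exception type above is real.
#
#     import asyncio
#     try:
#         response = await connector.callAiBasic(call)
#     except RateLimitExceededException as exc:
#         await asyncio.sleep(exc.retryAfterSeconds or 1.0)
#         response = await connector.callAiBasic(call)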


class ContextLengthExceededException(Exception):
    """Raised when the input exceeds a model's context window."""
    pass


class BaseConnectorAi(ABC):
    """
    Base class for all AI connectors.

    IMPORTANT: Models returned by getModels() must have unique displayName values.
    The displayName serves as the unique identifier in the model registry.
    Duplicate displayNames will cause registration to fail with an error.
    """

    def __init__(self):
        self._models_cache: Optional[List[AiModel]] = None
        self._last_cache_update: Optional[float] = None
        self._cache_ttl: float = 300.0  # 5 minutes cache TTL

    @abstractmethod
    def getModels(self) -> List[AiModel]:
        """
        Get all available models for this connector.
        Must be implemented by each connector.

        IMPORTANT: Each model's displayName must be unique across all connectors.
        If multiple models share the same API name (e.g., "gpt-4o"), they must have
        different displayNames (e.g., "OpenAI GPT-4o" vs "OpenAI GPT-4o Instance Vision").
        """
        pass

    @abstractmethod
    def getConnectorType(self) -> str:
        """
        Get the connector type identifier.
        Should return one of: openai, anthropic, perplexity, tavily
        """
        pass

    def getCachedModels(self) -> List[AiModel]:
        """
        Get cached models with TTL check.
        Returns cached models if still valid, otherwise refreshes the cache.
        """
        current_time = time.time()

        # Check if cache is still valid
        if (self._models_cache is not None and
                self._last_cache_update is not None and
                current_time - self._last_cache_update < self._cache_ttl):
            return self._models_cache

        # Refresh cache
        self._models_cache = self.getModels()
        self._last_cache_update = current_time

        return self._models_cache

    def clearCache(self):
        """Clear the models cache."""
        self._models_cache = None
        self._last_cache_update = None

    def getModelByDisplayName(self, displayName: str) -> Optional[AiModel]:
        """Get a specific model by displayName (displayName must be unique)."""
        models = self.getCachedModels()
        for model in models:
            if model.displayName == displayName:
                return model
        return None

    def getModelByName(self, name: str) -> Optional[AiModel]:
        """Get a specific model by its API name. Note: names may be duplicated; returns the first match."""
        models = self.getCachedModels()
        for model in models:
            if model.name == name:
                return model
        return None

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models that have a specific priority."""
        models = self.getCachedModels()
        return [model for model in models if model.priority == priority]

    def getAvailableModels(self) -> List[AiModel]:
        """Get only available models."""
        models = self.getCachedModels()
        return [model for model in models if model.isAvailable]
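
    async def callAiBasic(self, modelCall: AiModelCall) -> AiModelResponse:
        """Perform a single non-streaming AI call. Override in concrete connectors.

        NOTE: This declaration is an assumption added for completeness:
        callAiBasicStream() below falls back to it, but no definition of
        callAiBasic appears in this file. It mirrors callEmbedding's pattern
        of raising NotImplementedError in the base class.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement callAiBasic"
        )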

    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
        """Stream AI response. Yields str deltas during generation, then final AiModelResponse.

        Default implementation: falls back to non-streaming callAiBasic.
        Override in connectors that support streaming.
        """
        response = await self.callAiBasic(modelCall)
        if response.content:
            yield response.content
        yield response
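
    # Hedged consumer sketch: how a caller distinguishes text deltas from the
    # terminal AiModelResponse. `connector` and `call` are hypothetical names.
    #
    #     async for chunk in connector.callAiBasicStream(call):
    #         if isinstance(chunk, AiModelResponse):
    #             final = chunk              # terminal object (usage, metadata)
    #         else:
    #             print(chunk, end="")       # incremental text delta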

    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
        """Generate embeddings for input texts. Override in connectors that support embeddings.

        Reads texts from modelCall.embeddingInput.
        Returns AiModelResponse with metadata["embeddings"] containing the vectors.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support embeddings"
        )
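
    # Hedged consumer sketch for connectors that do implement embeddings.
    # `connector` and `call` are hypothetical; the metadata key follows the
    # docstring above.
    #
    #     response = await connector.callEmbedding(call)
    #     vectors = response.metadata["embeddings"]   # one vector per input text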
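

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the connector API).
# `_DemoConnector` is hypothetical, and the AiModel constructor is assumed to
# accept `name` and `displayName` keyword arguments; adjust to the real
# datamodel if it differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DemoConnector(BaseConnectorAi):
        """Minimal connector illustrating the unique-displayName requirement."""

        def getModels(self) -> List[AiModel]:
            # Both entries share the API name "gpt-4o" but carry distinct
            # displayNames, as the model registry requires.
            return [
                AiModel(name="gpt-4o", displayName="OpenAI GPT-4o"),
                AiModel(name="gpt-4o", displayName="OpenAI GPT-4o Instance Vision"),
            ]

        def getConnectorType(self) -> str:
            return "openai"

    connector = _DemoConnector()
    print([m.displayName for m in connector.getCachedModels()])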