# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Base connector interface for AI connectors.
All AI connectors should inherit from this class.

IMPORTANT: Model Registration Requirements
- Each model must have a unique displayName across all connectors
- The displayName is used as the unique identifier in the model registry
- The name field is used for API calls (can be duplicated across different model instances)
- If duplicate displayNames are detected during registration, an error will be raised
"""

import time
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, AsyncGenerator, Union

from modules.datamodels.datamodelAi import AiModel, AiModelCall, AiModelResponse


class BaseConnectorAi(ABC):
    """
    Base class for all AI connectors.

    IMPORTANT: Models returned by getModels() must have unique displayName values.
    The displayName serves as the unique identifier in the model registry.
    Duplicate displayNames will cause registration to fail with an error.

    The result of getModels() is cached per instance for ``_cache_ttl`` seconds
    (see getCachedModels); all lookup helpers read from that cache.
    """

    def __init__(self):
        # Cached result of getModels(); None means "not loaded yet" or "cleared".
        self._models_cache: Optional[List[AiModel]] = None
        # time.monotonic() reading taken when the cache was last refreshed.
        self._last_cache_update: Optional[float] = None
        self._cache_ttl: float = 300.0  # 5 minutes cache TTL

    @abstractmethod
    def getModels(self) -> List[AiModel]:
        """
        Get all available models for this connector.
        Should be implemented by each connector.

        IMPORTANT: Each model's displayName must be unique across all connectors.
        If multiple models share the same API name (e.g., "gpt-4o"), they must have
        different displayNames (e.g., "OpenAI GPT-4o" vs "OpenAI GPT-4o Instance Vision").
        """
        pass

    @abstractmethod
    def getConnectorType(self) -> str:
        """
        Get the connector type identifier.
        Should return one of: openai, anthropic, perplexity, tavily
        """
        pass

    def getCachedModels(self) -> List[AiModel]:
        """
        Get cached models with TTL check.

        Returns the cached list if it was refreshed less than ``_cache_ttl``
        seconds ago; otherwise calls getModels(), stores the result together
        with the refresh time, and returns it.

        Elapsed time is measured with time.monotonic() rather than the wall
        clock, so system clock adjustments (NTP corrections, manual changes)
        can neither keep a stale cache alive nor expire it prematurely.
        """
        now = time.monotonic()

        if self._models_cache is not None and self._last_cache_update is not None:
            elapsed = now - self._last_cache_update
            # A negative elapsed time means the stored timestamp did not come
            # from this clock (e.g. it was set externally); treat as expired.
            if 0 <= elapsed < self._cache_ttl:
                return self._models_cache

        # Cache missing or expired: refresh it from the connector.
        self._models_cache = self.getModels()
        self._last_cache_update = now
        return self._models_cache

    def clearCache(self):
        """Clear the models cache so the next lookup calls getModels() again."""
        self._models_cache = None
        self._last_cache_update = None

    def getModelByDisplayName(self, displayName: str) -> Optional[AiModel]:
        """
        Get a specific model by displayName (displayName must be unique).

        Returns None if no cached model has the given displayName.
        """
        return next(
            (model for model in self.getCachedModels() if model.displayName == displayName),
            None,
        )

    def getModelByName(self, name: str) -> Optional[AiModel]:
        """
        Get a specific model by name (API name).

        Note: name can be duplicated across models; the first match in
        getCachedModels() order is returned, or None if there is no match.
        """
        return next(
            (model for model in self.getCachedModels() if model.name == name),
            None,
        )

    def getModelsByPriority(self, priority: str) -> List[AiModel]:
        """Get models whose ``priority`` attribute equals the given value."""
        return [model for model in self.getCachedModels() if model.priority == priority]

    def getAvailableModels(self) -> List[AiModel]:
        """Get only models whose ``isAvailable`` attribute is truthy."""
        return [model for model in self.getCachedModels() if model.isAvailable]

    async def callAiBasicStream(self, modelCall: AiModelCall) -> AsyncGenerator[Union[str, AiModelResponse], None]:
        """Stream AI response. Yields str deltas during generation, then final AiModelResponse.

        Default implementation: falls back to non-streaming callAiBasic, yielding
        the whole content as a single delta (if non-empty) followed by the response.
        Override in connectors that support streaming.

        Note: callAiBasic is not declared on this base class; connectors that
        rely on this default must implement ``async callAiBasic(modelCall)``
        (presumably returning an AiModelResponse with a ``content`` attribute —
        confirm against the concrete connectors).
        """
        response = await self.callAiBasic(modelCall)
        if response.content:
            yield response.content
        yield response

    async def callEmbedding(self, modelCall: AiModelCall) -> AiModelResponse:
        """Generate embeddings for input texts.

        Override in connectors that support embeddings.
        Reads texts from modelCall.embeddingInput.
        Returns AiModelResponse with metadata["embeddings"] containing the vectors.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support embeddings"
        )