Fix model registration race with a refresh lock
This commit is contained in:
parent
564a1200c6
commit
60d5062204
1 changed file with 54 additions and 40 deletions
|
|
@ -38,6 +38,31 @@ class ModelRegistry:
|
||||||
self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {} # (user_id, rbac_id) -> (models, ts)
|
self._getAvailableModelsCache: Dict[Tuple[str, int], Tuple[List[AiModel], float]] = {} # (user_id, rbac_id) -> (models, ts)
|
||||||
self._getAvailableModelsCacheTtl: float = 30.0 # seconds
|
self._getAvailableModelsCacheTtl: float = 30.0 # seconds
|
||||||
|
|
||||||
|
def _addModelToDict(self, model: AiModel, connectorType: str, target: Dict[str, AiModel]):
    """Insert ``model`` into ``target`` under its displayName.

    A model that is already present with the same ``name`` and the same
    ``connectorType`` attribute is treated as a harmless repeat and skipped.
    Any other collision on displayName is an error.

    Args:
        model: The model to register; keyed by ``model.displayName``.
        connectorType: Type label of the connector supplying the model
            (used in log and error messages).
        target: Dict to insert into; mutated in place.

    Raises:
        ValueError: If ``target`` already holds a different model under the
            same displayName.
    """
    if model.displayName in target:
        existing = target[model.displayName]
        isSameModel = (
            existing.name == model.name
            and existing.connectorType == model.connectorType
        )
        if not isSameModel:
            raise ValueError(
                f"displayName conflict '{model.displayName}': "
                f"existing name='{existing.name}' (connector: {existing.connectorType}), "
                f"new name='{model.name}' (connector: {connectorType})"
            )
        logger.debug(f"Skipping duplicate model '{model.displayName}' from same connector {connectorType}")
        return

    # Testing hook: cap maxTokens when the module-level override is set.
    # Note that this mutates the model object itself, not a copy.
    overrideActive = TESTING_MAX_TOKENS_OVERRIDE is not None
    if overrideActive and model.maxTokens > TESTING_MAX_TOKENS_OVERRIDE:
        originalMaxTokens = model.maxTokens
        model.maxTokens = TESTING_MAX_TOKENS_OVERRIDE
        logger.debug(f"TESTING: Overrode maxTokens for {model.displayName}: {originalMaxTokens} -> {TESTING_MAX_TOKENS_OVERRIDE}")

    target[model.displayName] = model
    logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")
|
||||||
|
|
||||||
|
def _addModel(self, model: AiModel, connectorType: str):
    """Register ``model`` in the registry's own model table (``self._models``).

    Thin wrapper around ``_addModelToDict``; see that method for duplicate
    handling and the errors it raises.
    """
    registryModels = self._models
    self._addModelToDict(model, connectorType, registryModels)
|
||||||
|
|
||||||
def registerConnector(self, connector: BaseConnectorAi):
|
def registerConnector(self, connector: BaseConnectorAi):
|
||||||
"""Register a connector and collect its models."""
|
"""Register a connector and collect its models."""
|
||||||
connectorType = connector.getConnectorType()
|
connectorType = connector.getConnectorType()
|
||||||
|
|
@ -102,51 +127,40 @@ class ModelRegistry:
|
||||||
self._connectorsInitialized = True
|
self._connectorsInitialized = True
|
||||||
|
|
||||||
def refreshModels(self, force: bool = False):
    """Refresh models from all registered connectors. Thread-safe via _refreshLock.

    The refreshed model table is assembled in a local dict and assigned to
    ``self._models`` only after every connector has been processed, so a
    failed or in-progress refresh never exposes a cleared or partially
    filled registry.

    Args:
        force: When True, skip the refresh-interval check and refresh even
            if the previous refresh is still recent.

    Raises:
        Exception: Any error raised while clearing a connector's cache,
            listing its models, or adding a model (for example the
            ValueError from ``_addModelToDict`` on a displayName conflict)
            is logged and re-raised. ``self._models`` and
            ``self._lastRefresh`` are left unchanged in that case.
    """
    # Function-local import, as in the pre-change version of this method;
    # harmless if ``time`` is also imported at module level.
    import time

    self.ensureConnectorsRegistered()

    currentTime = time.time()

    # Nothing to do while the previous refresh is still within the interval.
    if (not force and
            self._lastRefresh is not None and
            currentTime - self._lastRefresh < self._refreshInterval):
        return

    # Non-blocking acquire: if the lock is already held, return immediately
    # instead of waiting and refreshing a second time.
    # NOTE(review): presumably _refreshLock is a threading.Lock/RLock created
    # in __init__ -- confirm. A caller that skips here continues with
    # whatever self._models currently holds.
    if not self._refreshLock.acquire(blocking=False):
        logger.debug("refreshModels already running in another thread, skipping")
        return

    try:
        logger.info("Refreshing model registry...")
        newModels: Dict[str, AiModel] = {}

        for connector in self._connectors.values():
            connectorType = connector.getConnectorType()
            try:
                connector.clearCache()
                models = connector.getCachedModels()
                for model in models:
                    self._addModelToDict(model, connectorType, newModels)
            except Exception as e:
                logger.error(f"Failed to refresh models from {connectorType}: {e}")
                raise

        # Swap in the fully built table with a single assignment rather than
        # clearing and refilling the live dict.
        self._models = newModels
        self._lastRefresh = time.time()
        logger.info(f"Model registry refreshed: {len(self._models)} models available")
    finally:
        self._refreshLock.release()
|
|
||||||
|
|
||||||
def getModel(self, displayName: str) -> Optional[AiModel]:
|
def getModel(self, displayName: str) -> Optional[AiModel]:
|
||||||
"""Get a specific model by displayName (displayName must be unique)."""
|
"""Get a specific model by displayName (displayName must be unique)."""
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue