gateway/modules/features/chatbot/bridges/ai.py
Ida Dittrich 6dc2afafb9 fix:performance improvements
- app.py: Pre-warm AI connectors at module load and in lifespan
- aicoreModelRegistry.py: Connector discovery cache, getAvailableModels cache, bulk RBAC, eager prewarm
- connectorDbPostgre.py: Connector cache, contextvars for userId, eviction (max 32)
- chatbot: Uses _get_cached_connector, Service center integration, BillingService exceptions instead of direct imports
- interfaceDbApp.py: Uses _get_cached_connector
- interfaceDbManagement.py: Uses _get_cached_connector
- security/rbac.py: Adds checkResourceAccessBulk
2026-03-06 13:46:54 +01:00

731 lines
32 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
AI Center to LangChain bridge.
Implements LangChain BaseChatModel interface using AI center models.
"""
import logging
import asyncio
import time
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
from datetime import datetime
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
AIMessage,
ToolMessage,
convert_to_openai_messages,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig
from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.datamodels.datamodelAi import (
AiModel,
AiModelCall,
AiModelResponse,
AiCallResponse,
AiCallOptions,
OperationTypeEnum,
ProcessingModeEnum,
)
from modules.datamodels.datamodelUam import User
logger = logging.getLogger(__name__)

# Workflow-level stores keyed by workflow_id. They survive LangGraph/bind_tools
# execution contexts where instance attributes may be lost when the model is
# wrapped or serialized.
_workflow_allowed_providers: Dict[str, List[str]] = {}
# workflow_id -> (mandateId, featureInstanceId), used for RBAC-scoped model lookup
_workflow_rbac_context: Dict[str, tuple] = {}


def clear_workflow_allowed_providers(workflow_id: str) -> None:
    """
    Drop every workflow-level registry entry stored for ``workflow_id``.

    Call this when a workflow's stream completes so the module-level stores do
    not grow without bound. Both the allowed-providers entry and the RBAC
    context entry are removed; unknown workflow IDs are silently ignored.

    Args:
        workflow_id: Workflow/conversation ID whose entries should be removed.
    """
    _workflow_allowed_providers.pop(workflow_id, None)
    # The RBAC context is registered per workflow by AICenterChatModel.__init__;
    # it must be released here too, otherwise it accumulates one entry per workflow.
    _workflow_rbac_context.pop(workflow_id, None)
class AICenterChatModel(BaseChatModel):
    """
    LangChain-compatible chat model that uses AI center models.

    Bridges AI center model selection and calling to LangChain's BaseChatModel
    interface. Custom state is stored with ``object.__setattr__`` to bypass
    Pydantic validation on ``BaseChatModel``.

    ``_agenerate`` dispatches to one of three call paths:

    * tools bound and an ``openai``/``privatellm`` model: a direct, non-streaming
      OpenAI-compatible chat-completions request with function calling (a
      Private-LLM 404 falls back to the model's connector);
    * no tools and an ``openai``/``privatellm`` model: a streaming request whose
      tokens are forwarded to ``run_manager.on_llm_new_token``;
    * otherwise: the selected model's connector ``functionCall``.

    Billing (price calculation + ``billing_callback``) runs exactly once per
    successful ``_agenerate`` call, whichever path was taken.
    """

    def __init__(
        self,
        user: User,
        operation_type: OperationTypeEnum = OperationTypeEnum.DATA_ANALYSE,
        processing_mode: ProcessingModeEnum = ProcessingModeEnum.DETAILED,
        billing_callback: Optional[Callable[[AiCallResponse], None]] = None,
        workflow_id: Optional[str] = None,
        allowed_providers: Optional[List[str]] = None,
        prefer_fast_model: bool = False,
        mandate_id: Optional[str] = None,
        feature_instance_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the AI center chat model bridge.

        Args:
            user: Current user for RBAC and model selection.
            operation_type: Operation type for model selection.
            processing_mode: Processing mode for model selection.
            billing_callback: Optional callback invoked after each successful
                _agenerate with an AiCallResponse for billing.
            workflow_id: Optional workflow/conversation ID for billing context
                and for the workflow-level registries.
            allowed_providers: Optional list of allowed provider connector types
                (empty/None = all).
            prefer_fast_model: When True, strongly prefer faster models
                (e.g. gpt-4o-mini for planner).
            mandate_id: Optional mandate ID passed to RBAC model filtering.
            feature_instance_id: Optional feature instance ID passed to RBAC
                model filtering.
            **kwargs: Additional arguments passed to BaseChatModel.
        """
        super().__init__(**kwargs)
        # Use object.__setattr__ to bypass Pydantic validation for custom attributes
        object.__setattr__(self, "user", user)
        object.__setattr__(self, "operation_type", operation_type)
        object.__setattr__(self, "processing_mode", processing_mode)
        object.__setattr__(self, "_selected_model", None)
        object.__setattr__(self, "_billing_callback", billing_callback)
        object.__setattr__(self, "_workflow_id", workflow_id)
        object.__setattr__(self, "_allowed_providers", allowed_providers or [])
        object.__setattr__(self, "_prefer_fast_model", prefer_fast_model)
        object.__setattr__(self, "_mandate_id", mandate_id)
        object.__setattr__(self, "_feature_instance_id", feature_instance_id)
        # Store in workflow-level registry so it survives when instance attrs are lost (e.g. bind_tools)
        if workflow_id and allowed_providers:
            _workflow_allowed_providers[workflow_id] = list(allowed_providers)
        if workflow_id and (mandate_id is not None or feature_instance_id is not None):
            _workflow_rbac_context[workflow_id] = (mandate_id, feature_instance_id)

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "aicenter"

    @staticmethod
    def _message_text(content: Any) -> str:
        """
        Flatten LangChain message content to plain text for model selection.

        Content may be a plain string or a list of content blocks (strings or
        dicts such as ``{"type": "text", "text": ...}``); non-text blocks are
        skipped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: List[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "\n".join(parts)
        return "" if content is None else str(content)

    def _select_model(self, messages: List[BaseMessage]) -> AiModel:
        """
        Select the best AI center model for the given messages.

        The result is cached on the instance: only the first call performs
        selection, and later calls return the cached model regardless of the
        messages passed.

        Args:
            messages: List of LangChain messages.

        Returns:
            The selected AI model.

        Raises:
            ValueError: If the model selector returns no model.
        """
        # Return cached model if already selected (significant performance improvement)
        if self._selected_model is not None:
            return self._selected_model

        # System/human messages form the prompt; assistant and tool output form the context.
        prompt_parts: List[str] = []
        context_parts: List[str] = []
        for msg in messages:
            text = self._message_text(msg.content)
            if isinstance(msg, (SystemMessage, HumanMessage)):
                prompt_parts.append(text)
            elif isinstance(msg, AIMessage):
                context_parts.append(text)
            elif isinstance(msg, ToolMessage):
                context_parts.append(f"Tool {msg.name}: {text}")
        prompt = "\n".join(prompt_parts)
        context = "\n".join(context_parts)

        # Get available models with RBAC filtering (cached/singleton interfaces)
        from modules.interfaces.interfaceDbApp import getRootInterface

        workflow_id = getattr(self, "_workflow_id", None)
        rbac_instance = getRootInterface().rbac

        # RBAC scope: instance attributes first; fall back to the workflow-level
        # store, since instance attributes may be missing after LangGraph wrapping.
        mandate_id = getattr(self, "_mandate_id", None)
        feature_instance_id = getattr(self, "_feature_instance_id", None)
        if workflow_id and mandate_id is None and feature_instance_id is None:
            ctx = _workflow_rbac_context.get(workflow_id)
            if ctx:
                mandate_id, feature_instance_id = ctx

        available_models = modelRegistry.getAvailableModels(
            currentUser=self.user,
            rbacInstance=rbac_instance,
            mandateId=mandate_id,
            featureInstanceId=feature_instance_id,
        )

        # Allowed providers: the workflow store takes precedence over the instance
        # attribute (which may be lost in LangGraph/bind_tools contexts).
        allowed = (
            (_workflow_allowed_providers.get(workflow_id) if workflow_id else None)
            or getattr(self, "_allowed_providers", None)
            or []
        )
        if allowed:
            logger.info(f"AICenterChatModel _select_model: applying allowedProviders={allowed}")
            filtered = [m for m in available_models if m.connectorType in allowed]
            if filtered:
                available_models = filtered
            else:
                logger.warning(f"No models match allowedProviders {allowed}, using all RBAC-permitted models")
                # NOTE(review): allowedProviders is still passed to the selector below;
                # if modelSelector filters on it as well, this fallback may still find
                # no model — confirm against modelSelector.selectModel.

        options = AiCallOptions(
            operationType=self.operation_type,
            processingMode=self.processing_mode,
            allowedProviders=allowed if allowed else None,
            preferFastModel=getattr(self, "_prefer_fast_model", False),
        )

        selected_model = modelSelector.selectModel(
            prompt=prompt,
            context=context,
            options=options,
            availableModels=available_models,
        )
        if not selected_model:
            raise ValueError(f"No suitable model found for operation type {self.operation_type.value}")

        logger.info(f"Selected AI center model: {selected_model.displayName} ({selected_model.name})")
        object.__setattr__(self, "_selected_model", selected_model)
        return selected_model

    def _convert_messages_to_ai_format(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]:
        """
        Convert LangChain messages to AI center format (OpenAI-style dicts).

        Args:
            messages: List of LangChain messages.

        Returns:
            List of messages in OpenAI chat format.
        """
        return convert_to_openai_messages(messages)

    def _convert_ai_response_to_langchain(
        self,
        response: AiModelResponse,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> AIMessage:
        """
        Convert an AI center response to a LangChain AIMessage.

        Args:
            response: AI center response.
            tool_calls: Optional tool calls already in LangChain format
                (``[{"id": "...", "name": "...", "args": {...}}]``).

        Returns:
            AIMessage carrying the response content and, if present, tool_calls.
        """
        kwargs = {}
        if tool_calls:
            kwargs["tool_calls"] = tool_calls
        return AIMessage(content=response.content or "", **kwargs)

    @staticmethod
    def _in_running_loop() -> bool:
        """Return True when called from a thread that is currently running an asyncio loop."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return False
        return True

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Synchronous generate method required by BaseChatModel.

        Drives ``_agenerate`` on a fresh event loop via ``asyncio.run``.

        Args:
            messages: List of LangChain messages.
            stop: Optional stop sequences (forwarded; currently unused by _agenerate).
            run_manager: Optional callback manager.
            **kwargs: Additional arguments forwarded to _agenerate.

        Returns:
            ChatResult with generations.

        Raises:
            RuntimeError: If called from inside a running event loop.
        """
        # Check before creating the coroutine so nothing is left un-awaited.
        if self._in_running_loop():
            raise RuntimeError(
                "AICenterChatModel._generate() called from async context. "
                "Use _agenerate() instead."
            )
        return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs))

    def _resolve_openai_compatible_endpoint(self) -> tuple:
        """
        Build the OpenAI-compatible chat-completions endpoint for the selected model.

        Returns:
            ``(api_url, headers, model_name)``. For ``openai`` the model's
            ``apiUrl`` (default: the public OpenAI endpoint) and a Bearer key are
            used. For any other connector (Private-LLM/Ollama) ``/api/analyze`` is
            stripped from ``apiUrl`` and ``/v1/chat/completions`` is appended; the
            model's ``version`` (e.g. ``qwen2.5:7b``) is preferred as model name.

        Raises:
            ValueError: If the OpenAI key or the Private-LLM base URL is missing.
        """
        from modules.shared.configuration import APP_CONFIG

        model = self._selected_model
        if model.connectorType == "openai":
            api_url = getattr(model, "apiUrl", None) or "https://api.openai.com/v1/chat/completions"
            api_key = APP_CONFIG.get("Connector_AiOpenai_API_SECRET")
            if not api_key:
                raise ValueError("OpenAI API key not configured")
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            return api_url, headers, model.name

        # privatellm: Ollama OpenAI-compatible /v1/chat/completions (same service, same provider)
        base_url = (getattr(model, "apiUrl", None) or "").replace("/api/analyze", "")
        if not base_url:
            raise ValueError(f"Model {model.displayName} has no apiUrl configured")
        api_url = f"{base_url.rstrip('/')}/v1/chat/completions"
        api_key = APP_CONFIG.get("Connector_AiPrivateLlm_API_SECRET")
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["X-API-Key"] = api_key
        # Ollama needs the underlying model name (e.g. qwen2.5:7b), not poweron-text-general
        model_name = getattr(model, "version", None) or model.name
        return api_url, headers, model_name

    @staticmethod
    def _parse_sse_line(line: str) -> tuple:
        """
        Parse one server-sent-events line of a chat-completions stream.

        Returns:
            ``(token, done)``: ``token`` is the delta content (or None) and
            ``done`` is True once the ``[DONE]`` sentinel is seen. Non-data lines
            and undecodable/unexpected payloads yield ``(None, False)``.
        """
        import json as _json

        line = line.strip()
        if not line.startswith("data: "):
            return None, False
        data_str = line[6:].strip()
        if data_str == "[DONE]":
            return None, True
        try:
            data = _json.loads(data_str)
        except _json.JSONDecodeError:
            return None, False
        if not isinstance(data, dict):
            return None, False
        choices = data.get("choices") or []
        if not choices or not isinstance(choices[0], dict):
            return None, False
        delta = choices[0].get("delta") or {}
        if not isinstance(delta, dict):
            return None, False
        return (delta.get("content") or None), False

    @staticmethod
    async def _emit_token(run_manager: Optional[Any], token: str) -> None:
        """
        Forward a streamed token to the LangChain callback manager, if any.

        LangChain's async callback manager exposes ``on_llm_new_token`` as a
        coroutine function, so its result is awaited when awaitable; sync
        managers are supported as well.
        """
        import inspect

        if not run_manager or not hasattr(run_manager, "on_llm_new_token"):
            return
        result = run_manager.on_llm_new_token(token)
        if inspect.isawaitable(result):
            await result

    async def _call_openai_streaming(
        self,
        ai_messages: List[dict],
        run_manager: Optional[Any],
        model_call: "AiModelCall",
        input_bytes: int,
        start_time: float,
    ) -> "AiModelResponse":
        """
        Call OpenAI/Ollama with ``stream=True``, emit tokens via run_manager, return the full response.

        Billing is NOT performed here; ``_agenerate`` bills the returned
        response once, like every other call path.

        Args:
            ai_messages: Messages in OpenAI chat format.
            run_manager: Optional LangChain callback manager receiving tokens.
            model_call: Unused; kept for signature compatibility.
            input_bytes: Unused; kept for signature compatibility.
            start_time: Unused; kept for signature compatibility.

        Returns:
            AiModelResponse with the concatenated streamed content.

        Raises:
            ValueError: On a non-200 HTTP status or missing endpoint configuration.
        """
        import httpx

        api_url, headers, model_name = self._resolve_openai_compatible_endpoint()
        payload = {
            "model": model_name,
            "messages": ai_messages,
            "temperature": self._selected_model.temperature,
            "max_tokens": self._selected_model.maxTokens,
            "stream": True,
        }

        content_parts: List[str] = []
        async with httpx.AsyncClient(timeout=600.0) as client:
            async with client.stream("POST", api_url, headers=headers, json=payload) as resp:
                if resp.status_code != 200:
                    raise ValueError(f"OpenAI stream error: {resp.status_code} - {await resp.aread()}")
                buffer = ""
                done = False
                async for chunk in resp.aiter_text():
                    buffer += chunk
                    while "\n" in buffer and not done:
                        line, _, buffer = buffer.partition("\n")
                        token, done = self._parse_sse_line(line)
                        if token:
                            await self._emit_token(run_manager, token)
                            content_parts.append(token)
                    # Stop reading once [DONE] arrived, not just the current chunk.
                    if done:
                        break
                # A final event may arrive without a trailing newline.
                if not done and buffer.strip():
                    token, _ = self._parse_sse_line(buffer)
                    if token:
                        await self._emit_token(run_manager, token)
                        content_parts.append(token)

        content = "".join(content_parts)
        return AiModelResponse(content=content, success=True, modelId=self._selected_model.name, metadata={})

    @staticmethod
    def _tool_parameters_schema(args_schema: Any) -> Dict[str, Any]:
        """
        Derive a parameters dict for a tool's ``args_schema``.

        Handles Pydantic model classes (v2 ``model_json_schema`` / v1 ``schema``),
        Pydantic instances (dumped as data), other objects with ``schema()``,
        and plain dicts. Anything else yields ``{}``.
        """
        from pydantic import BaseModel

        if not args_schema:
            return {}
        if isinstance(args_schema, type) and issubclass(args_schema, BaseModel):
            if hasattr(args_schema, "model_json_schema"):
                return args_schema.model_json_schema()  # Pydantic v2
            if hasattr(args_schema, "schema"):
                return args_schema.schema()  # Pydantic v1
            return {}
        if isinstance(args_schema, BaseModel):
            if hasattr(args_schema, "model_dump"):
                return args_schema.model_dump()  # Pydantic v2
            if hasattr(args_schema, "dict"):
                return args_schema.dict()  # Pydantic v1
            return {}
        if hasattr(args_schema, "schema"):
            try:
                return args_schema.schema()
            except TypeError:
                # schema() may require an instance; try the class-level v2 API instead.
                if hasattr(args_schema, "model_json_schema"):
                    return args_schema.model_json_schema()
                return {}
        if isinstance(args_schema, dict):
            return args_schema
        return {}

    @classmethod
    def _inject_tool_instructions(cls, ai_messages: List[Dict[str, Any]], tools: List[Any]) -> None:
        """
        Describe the bound tools in the system message (mutates ``ai_messages`` in place).

        Appends the tool list and usage instructions to the first system message,
        or inserts a new system message at the front when none exists. Some
        models need these explicit definitions to use tool calling.
        """
        tool_descriptions = []
        for tool in tools:
            if not (hasattr(tool, "name") and hasattr(tool, "description")):
                continue
            params_info = ""
            try:
                schema = cls._tool_parameters_schema(getattr(tool, "args_schema", None))
                properties = schema.get("properties") if isinstance(schema, dict) else None
                if properties:
                    params_info = f" (Parameter: {', '.join(properties)})"
            except Exception as e:
                # Parameter hints are best-effort; the tool is still listed without them.
                logger.debug(f"Could not derive parameters for tool {tool.name}: {e}")
            tool_descriptions.append(f"- {tool.name}: {tool.description}{params_info}")

        if not tool_descriptions:
            return

        tools_text = "\n".join(tool_descriptions)
        tools_note = f"\n\n⚠️⚠️⚠️ KRITISCH - TOOL-NUTZUNG ⚠️⚠️⚠️\n\nVERFÜGBARE TOOLS:\n{tools_text}\n\nABSOLUT VERBINDLICH:\n- Du MUSST diese Tools verwenden, um Anfragen zu bearbeiten!\n- Für Status-Updates MUSST du IMMER das Tool 'send_streaming_message' verwenden!\n- VERBOTEN: Normale Text-Nachrichten für Status-Updates!\n- Du MUSST Tools aufrufen, nicht nur darüber sprechen!\n\nBeispiel FALSCH: \"Ich werde die Datenbank durchsuchen...\"\nBeispiel RICHTIG: Rufe das Tool 'send_streaming_message' mit \"Durchsuche Datenbank...\" auf!"

        system_message = next((m for m in ai_messages if m.get("role") == "system"), None)
        if system_message is None:
            ai_messages.insert(0, {"role": "system", "content": tools_note.strip()})
            return
        existing = system_message.get("content")
        if isinstance(existing, str):
            system_message["content"] = existing + tools_note
        elif isinstance(existing, list):
            # Content-block form: add the note as an extra text block.
            existing.append({"type": "text", "text": tools_note})
        else:
            system_message["content"] = tools_note.strip()

    @classmethod
    def _build_openai_tools(cls, tools: List[Any]) -> List[Dict[str, Any]]:
        """Convert LangChain tools to the OpenAI function-calling schema (also used by Ollama)."""
        openai_tools = []
        for tool in tools:
            if not (hasattr(tool, "name") and hasattr(tool, "description")):
                continue
            openai_tools.append({
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": cls._tool_parameters_schema(getattr(tool, "args_schema", None)),
                },
            })
        return openai_tools

    async def _call_connector(self, model_call: AiModelCall) -> AiModelResponse:
        """
        Call the selected model through its AI center connector.

        Raises:
            ValueError: If the model defines no functionCall.
        """
        if not self._selected_model.functionCall:
            raise ValueError(f"Model {self._selected_model.displayName} has no functionCall defined")
        return await self._selected_model.functionCall(model_call)

    async def _call_openai_with_tools(
        self,
        ai_messages: List[Dict[str, Any]],
        openai_tools: List[Dict[str, Any]],
        model_call: AiModelCall,
    ) -> AiModelResponse:
        """
        Non-streaming OpenAI-compatible call with function calling.

        For Private-LLM a 404 on ``/v1/chat/completions`` falls back to the
        connector (without tool calling). Tool calls in the reply are returned
        in ``metadata["tool_calls"]`` in LangChain format.

        Raises:
            ValueError: On any other non-200 HTTP status.
        """
        import httpx
        import json as _json

        api_url, headers, model_name = self._resolve_openai_compatible_endpoint()
        payload = {
            "model": model_name,
            "messages": ai_messages,
            "tools": openai_tools,
            "tool_choice": "auto",
            "temperature": self._selected_model.temperature,
            "max_tokens": self._selected_model.maxTokens,
        }
        async with httpx.AsyncClient(timeout=600.0) as client:
            response_obj = await client.post(api_url, headers=headers, json=payload)

        if response_obj.status_code == 404 and self._selected_model.connectorType == "privatellm":
            logger.warning(
                "Private-LLM /v1/chat/completions not found (404). Falling back to /api/analyze. "
                "Tool calling will not work until the service exposes an OpenAI-compatible endpoint."
            )
            return await self._call_connector(model_call)
        if response_obj.status_code != 200:
            error_msg = f"AI API error ({self._selected_model.connectorType}): {response_obj.status_code} - {response_obj.text}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        response_json = response_obj.json()
        message = response_json["choices"][0]["message"]
        content = message.get("content", "")
        tool_calls = None
        tool_calls_raw = message.get("tool_calls")
        if tool_calls_raw:
            tool_calls = []
            for tc in tool_calls_raw:
                func_data = tc.get("function", {})
                func_args_str = func_data.get("arguments", "{}")
                try:
                    func_args = _json.loads(func_args_str) if isinstance(func_args_str, str) else func_args_str
                except Exception:
                    # Malformed arguments from the model: call the tool with no args.
                    func_args = {}
                tool_calls.append({
                    "id": tc.get("id"),
                    "name": func_data.get("name"),
                    "args": func_args,
                })

        return AiModelResponse(
            content=content or "",
            success=True,
            modelId=self._selected_model.name,
            metadata={
                "response_id": response_json.get("id", ""),
                "tool_calls": tool_calls,
            },
        )

    def _report_billing(self, response: AiModelResponse, input_bytes: int, start_time: float) -> None:
        """
        Price the call and hand an AiCallResponse to the billing callback, if any.

        Errors in price calculation or in the callback are logged, never raised.
        """
        output_bytes = len((response.content or "").encode("utf-8"))
        processing_time = time.time() - start_time
        price_chf = 0.0
        if getattr(self._selected_model, "calculatepriceCHF", None):
            try:
                price_chf = self._selected_model.calculatepriceCHF(
                    processing_time, input_bytes, output_bytes
                )
            except Exception as e:
                logger.warning(f"Billing: price calculation failed: {e}")

        billing_callback = getattr(self, "_billing_callback", None)
        if not billing_callback:
            return
        try:
            billing_callback(AiCallResponse(
                content=response.content or "",
                modelName=self._selected_model.name,
                provider=getattr(self._selected_model, "connectorType", "unknown") or "unknown",
                priceCHF=price_chf,
                processingTime=processing_time,
                bytesSent=input_bytes,
                bytesReceived=output_bytes,
                errorCount=0,
            ))
        except Exception as e:
            logger.error(f"Billing callback error: {e}")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Async generate method required by BaseChatModel.

        Args:
            messages: List of LangChain messages.
            stop: Optional stop sequences (currently not forwarded to the model).
            run_manager: Optional callback manager (receives streamed tokens).
            **kwargs: Additional arguments (currently unused).

        Returns:
            ChatResult with a single generation (tool calls attached if present).

        Raises:
            ValueError: If no model can be selected, the call fails, or the
                model reports ``success=False``.
        """
        if not self._selected_model:
            self._select_model(messages)

        # Tools are attached by bind_tools
        tools = getattr(self, "_bound_tools", None)
        ai_messages = self._convert_messages_to_ai_format(messages)

        # Input bytes for billing: string message contents, measured before tool instructions are added.
        input_bytes = sum(
            len((m.get("content") or "").encode("utf-8"))
            for m in ai_messages
            if isinstance(m.get("content"), str)
        )
        start_time = time.time()

        if tools:
            self._inject_tool_instructions(ai_messages, tools)

        openai_compatible = self._selected_model.connectorType in ("openai", "privatellm")
        openai_tools = self._build_openai_tools(tools) if tools and openai_compatible else None

        model_call = AiModelCall(
            messages=ai_messages,
            model=self._selected_model,
            options=AiCallOptions(
                operationType=self.operation_type,
                processingMode=self.processing_mode,
                temperature=self._selected_model.temperature,
            ),
        )

        if openai_tools:
            response = await self._call_openai_with_tools(ai_messages, openai_tools, model_call)
        elif not tools and openai_compatible:
            # Streaming path for OpenAI/Ollama without tools (ChatGPT-like token streaming)
            response = await self._call_openai_streaming(
                ai_messages, run_manager, model_call, input_bytes, start_time
            )
        else:
            response = await self._call_connector(model_call)

        if not response.success:
            raise ValueError(f"AI model call failed: {response.error or 'Unknown error'}")

        # Single billing point for all call paths.
        self._report_billing(response, input_bytes, start_time)

        # Tool calls live in metadata; the key name may vary by connector.
        tool_calls = None
        if response.metadata:
            tool_calls = response.metadata.get("tool_calls") or response.metadata.get("function_calls")

        ai_message = self._convert_ai_response_to_langchain(response, tool_calls=tool_calls)
        return ChatResult(generations=[ChatGeneration(message=ai_message)])

    def bind_tools(self, tools: List[Any], **kwargs: Any) -> "AICenterChatModel":
        """
        Bind tools to the model (required for LangGraph tool calling).

        Returns a new instance that shares this instance's configuration
        (billing, workflow, provider/RBAC scope, fast-model preference, and any
        already selected model) and carries ``tools`` for _agenerate.

        Args:
            tools: List of LangChain tools.
            **kwargs: Additional arguments (ignored, e.g. tool_choice).

        Returns:
            New instance with tools bound.
        """
        bound_model = AICenterChatModel(
            user=self.user,
            operation_type=self.operation_type,
            processing_mode=self.processing_mode,
            billing_callback=getattr(self, "_billing_callback", None),
            workflow_id=getattr(self, "_workflow_id", None),
        )
        # Copy the remaining state directly so the constructor does not
        # re-register workflow-level entries as a side effect.
        carried_state = {
            "_allowed_providers": [],
            "_prefer_fast_model": False,
            "_mandate_id": None,
            "_feature_instance_id": None,
        }
        for attr, default in carried_state.items():
            object.__setattr__(bound_model, attr, getattr(self, attr, default))
        object.__setattr__(bound_model, "_selected_model", self._selected_model)
        object.__setattr__(bound_model, "_bound_tools", tools)
        return bound_model

    def invoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Synchronous invoke; runs ``_agenerate`` on a fresh event loop.

        Note: ``config`` is ignored and no callback manager is passed, so
        LangChain callbacks (including token streaming) do not fire here.

        Args:
            input: List of LangChain messages.
            config: Optional runnable config (ignored).
            **kwargs: Additional arguments forwarded to _agenerate.

        Returns:
            AIMessage response.

        Raises:
            RuntimeError: If called from inside a running event loop.
        """
        if self._in_running_loop():
            raise RuntimeError("Cannot use synchronous invoke in async context. Use ainvoke instead.")
        result = asyncio.run(self._agenerate(input, **kwargs))
        return result.generations[0].message

    async def ainvoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Async invoke; calls ``_agenerate`` directly.

        Note: ``config`` is ignored and no callback manager is passed, so
        LangChain callbacks (including token streaming) do not fire here.

        Args:
            input: List of LangChain messages.
            config: Optional runnable config (ignored).
            **kwargs: Additional arguments forwarded to _agenerate.

        Returns:
            AIMessage response.
        """
        result = await self._agenerate(input, **kwargs)
        return result.generations[0].message