534 lines
22 KiB
Python
534 lines
22 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
AI Center to LangChain bridge.
|
|
Implements LangChain BaseChatModel interface using AI center models.
|
|
"""
|
|
|
|
import logging
|
|
import asyncio
|
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
from datetime import datetime
|
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
from langchain_core.messages import (
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
AIMessage,
|
|
ToolMessage,
|
|
convert_to_openai_messages,
|
|
)
|
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
|
from langchain_core.runnables import RunnableConfig
|
|
|
|
from modules.aicore.aicoreModelRegistry import modelRegistry
|
|
from modules.aicore.aicoreModelSelector import modelSelector
|
|
from modules.datamodels.datamodelAi import (
|
|
AiModel,
|
|
AiModelCall,
|
|
AiModelResponse,
|
|
AiCallOptions,
|
|
OperationTypeEnum,
|
|
ProcessingModeEnum,
|
|
)
|
|
from modules.datamodels.datamodelUam import User
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AICenterChatModel(BaseChatModel):
    """
    Chat model adapter that exposes AI center models through LangChain.

    Subclasses LangChain's BaseChatModel and delegates model selection and
    model invocation to the AI center registry, selector and model calls.
    """
|
|
|
|
def __init__(
    self,
    user: User,
    operation_type: OperationTypeEnum = OperationTypeEnum.DATA_ANALYSE,
    processing_mode: ProcessingModeEnum = ProcessingModeEnum.DETAILED,
    **kwargs
):
    """
    Set up the bridge for one user and one model-selection profile.

    Args:
        user: Current user, used for RBAC filtering during model selection
        operation_type: Operation type handed to the model selector
        processing_mode: Processing mode handed to the model selector
        **kwargs: Forwarded unchanged to BaseChatModel.__init__
    """
    super().__init__(**kwargs)

    # These attributes are set with object.__setattr__ so that Pydantic's
    # attribute validation on the base model is bypassed.
    custom_attributes = (
        ("user", user),
        ("operation_type", operation_type),
        ("processing_mode", processing_mode),
        ("_selected_model", None),  # cache slot, filled by _select_model
    )
    for attr_name, attr_value in custom_attributes:
        object.__setattr__(self, attr_name, attr_value)
|
|
|
|
@property
|
|
def _llm_type(self) -> str:
|
|
"""Return type of LLM."""
|
|
return "aicenter"
|
|
|
|
def _select_model(self, messages: List[BaseMessage]) -> AiModel:
|
|
"""
|
|
Select the best AI center model for the given messages.
|
|
Uses caching to avoid repeated model selection within same session.
|
|
|
|
Args:
|
|
messages: List of LangChain messages
|
|
|
|
Returns:
|
|
Selected AI model
|
|
"""
|
|
# Return cached model if already selected (significant performance improvement)
|
|
if self._selected_model is not None:
|
|
return self._selected_model
|
|
|
|
# Convert messages to prompt/context format for model selector
|
|
prompt_parts = []
|
|
context_parts = []
|
|
|
|
for msg in messages:
|
|
if isinstance(msg, SystemMessage):
|
|
prompt_parts.append(msg.content)
|
|
elif isinstance(msg, HumanMessage):
|
|
prompt_parts.append(msg.content)
|
|
elif isinstance(msg, AIMessage):
|
|
context_parts.append(msg.content)
|
|
elif isinstance(msg, ToolMessage):
|
|
context_parts.append(f"Tool {msg.name}: {msg.content}")
|
|
|
|
prompt = "\n".join(prompt_parts)
|
|
context = "\n".join(context_parts) if context_parts else ""
|
|
|
|
# Get available models with RBAC filtering
|
|
# Use cached/singleton interfaces for better performance
|
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
|
|
|
rootInterface = getRootInterface()
|
|
rbac_instance = rootInterface.rbac
|
|
|
|
available_models = modelRegistry.getAvailableModels(
|
|
currentUser=self.user,
|
|
rbacInstance=rbac_instance
|
|
)
|
|
|
|
# Create options for model selector
|
|
options = AiCallOptions(
|
|
operationType=self.operation_type,
|
|
processingMode=self.processing_mode
|
|
)
|
|
|
|
# Select model
|
|
selected_model = modelSelector.selectModel(
|
|
prompt=prompt,
|
|
context=context,
|
|
options=options,
|
|
availableModels=available_models
|
|
)
|
|
|
|
if not selected_model:
|
|
raise ValueError(f"No suitable model found for operation type {self.operation_type.value}")
|
|
|
|
logger.info(f"Selected AI center model: {selected_model.displayName} ({selected_model.name})")
|
|
object.__setattr__(self, "_selected_model", selected_model)
|
|
return selected_model
|
|
|
|
def _convert_messages_to_ai_format(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]:
    """
    Translate LangChain messages into the OpenAI-style format used for AI center calls.

    Args:
        messages: List of LangChain messages

    Returns:
        The result of LangChain's convert_to_openai_messages for ``messages``
    """
    # The conversion is delegated entirely to LangChain's built-in helper
    return convert_to_openai_messages(messages)
|
|
|
|
def _convert_ai_response_to_langchain(
    self,
    response: AiModelResponse,
    tool_calls: Optional[List[Dict[str, Any]]] = None
) -> AIMessage:
    """
    Wrap an AI center response in a LangChain AIMessage.

    Args:
        response: AI center response; an empty/missing content becomes ""
        tool_calls: Optional tool calls, expected to already be in LangChain
            format: [{"id": "...", "name": "...", "args": {...}}]

    Returns:
        AIMessage carrying the response text, plus tool_calls when a
        non-empty list was given
    """
    text = response.content or ""

    # Only pass tool_calls through when there is something to pass
    if not tool_calls:
        return AIMessage(content=text)
    return AIMessage(content=text, tool_calls=tool_calls)
|
|
|
|
def _generate(
|
|
self,
|
|
messages: List[BaseMessage],
|
|
stop: Optional[List[str]] = None,
|
|
run_manager: Optional[Any] = None,
|
|
**kwargs: Any,
|
|
) -> ChatResult:
|
|
"""
|
|
Synchronous generate method required by BaseChatModel.
|
|
Wraps the async _agenerate method.
|
|
|
|
Args:
|
|
messages: List of LangChain messages
|
|
stop: Optional stop sequences
|
|
run_manager: Optional callback manager
|
|
**kwargs: Additional arguments
|
|
|
|
Returns:
|
|
ChatResult with generations
|
|
"""
|
|
# Try to get the current event loop
|
|
try:
|
|
loop = asyncio.get_event_loop()
|
|
if loop.is_running():
|
|
# If we're in an async context, raise an error
|
|
raise RuntimeError(
|
|
"AICenterChatModel._generate() called from async context. "
|
|
"Use _agenerate() instead."
|
|
)
|
|
except RuntimeError:
|
|
# No event loop, we can create one
|
|
pass
|
|
|
|
# Run the async method synchronously
|
|
return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs))
|
|
|
|
async def _agenerate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[Any] = None,
    **kwargs: Any,
) -> ChatResult:
    """
    Async generate method required by BaseChatModel.

    Flow:
      1. Select the model (once; cached on the instance).
      2. Convert the messages to OpenAI-style dicts.
      3. If tools were bound via bind_tools(), describe them in the system
         message.
      4. If tools are bound and the model's connector type is "openai", post
         directly to the model's API URL with the tool schemas; otherwise
         call the model's functionCall connector.
      5. Wrap the answer (and any tool calls) in a ChatResult.

    Args:
        messages: List of LangChain messages
        stop: Optional stop sequences (accepted for interface compatibility;
            not forwarded to the model by this implementation)
        run_manager: Optional callback manager (not used by this implementation)
        **kwargs: Additional arguments (not used by this implementation)

    Returns:
        ChatResult with a single generation

    Raises:
        ValueError: If the OpenAI API key is missing, the HTTP call does not
            return status 200, the model has no functionCall, or the
            connector reports a failed call
    """
    import json

    # Select model if not already selected
    if not self._selected_model:
        self._select_model(messages)
    selected_model = self._selected_model

    # Tools are attached by bind_tools(); an instance without bound tools
    # does not have the attribute at all, hence getattr with a default.
    tools = getattr(self, "_bound_tools", None)

    # Convert messages to AI center format
    ai_messages = self._convert_messages_to_ai_format(messages)

    # If tools are bound, describe them in the system message so the model
    # is told about them even when the connector cannot pass tool schemas.
    if tools:
        system_message_idx = None
        for i, msg in enumerate(ai_messages):
            if msg.get("role") == "system":
                system_message_idx = i
                break

        # Build one "- name: description (Parameter: a, b)" line per tool
        tool_descriptions = []
        for tool in tools:
            if not (hasattr(tool, "name") and hasattr(tool, "description")):
                continue
            args_schema = getattr(tool, "args_schema", None)
            params_info = ""
            if args_schema:
                # Best effort: a tool whose schema cannot be read is still
                # listed, just without its parameter names.
                try:
                    if hasattr(args_schema, "model_json_schema"):
                        schema = args_schema.model_json_schema()
                        if "properties" in schema:
                            params = list(schema["properties"].keys())
                            params_info = f" (Parameter: {', '.join(params)})"
                except Exception:
                    logger.debug(
                        "Could not read parameter schema for tool %s",
                        tool.name,
                        exc_info=True,
                    )
            tool_descriptions.append(f"- {tool.name}: {tool.description}{params_info}")

        if tool_descriptions:
            tools_text = "\n".join(tool_descriptions)
            tools_note = f"\n\n⚠️⚠️⚠️ KRITISCH - TOOL-NUTZUNG ⚠️⚠️⚠️\n\nVERFÜGBARE TOOLS:\n{tools_text}\n\nABSOLUT VERBINDLICH:\n- Du MUSST diese Tools verwenden, um Anfragen zu bearbeiten!\n- Für Status-Updates MUSST du IMMER das Tool 'send_streaming_message' verwenden!\n- VERBOTEN: Normale Text-Nachrichten für Status-Updates!\n- Du MUSST Tools aufrufen, nicht nur darüber sprechen!\n\nBeispiel FALSCH: \"Ich werde die Datenbank durchsuchen...\"\nBeispiel RICHTIG: Rufe das Tool 'send_streaming_message' mit \"Durchsuche Datenbank...\" auf!"

            if system_message_idx is not None:
                # Append to existing system message
                system_message = ai_messages[system_message_idx]
                existing_content = system_message.get("content")
                if isinstance(existing_content, list):
                    # Content given as a list of parts: `list += str` would
                    # splice in single characters, so add a text part instead.
                    system_message["content"] = existing_content + [
                        {"type": "text", "text": tools_note}
                    ]
                else:
                    system_message["content"] = (existing_content or "") + tools_note
            else:
                # Add new system message at the beginning
                ai_messages.insert(0, {
                    "role": "system",
                    "content": tools_note.strip()
                })

    # Convert LangChain tools to OpenAI tool format. This is only done for
    # models whose connector type is "openai", because only that path below
    # sends tool schemas to the API.
    openai_tools = None
    if tools and selected_model.connectorType == "openai":
        from pydantic import BaseModel

        openai_tools = []
        for tool in tools:
            if not (hasattr(tool, "name") and hasattr(tool, "description")):
                continue

            # Get tool parameters schema
            args_schema = getattr(tool, "args_schema", None)
            parameters = {}
            if args_schema:
                if isinstance(args_schema, type) and issubclass(args_schema, BaseModel):
                    # Pydantic model class - get JSON schema
                    if hasattr(args_schema, "model_json_schema"):
                        # Pydantic v2
                        parameters = args_schema.model_json_schema()
                    elif hasattr(args_schema, "schema"):
                        # Pydantic v1
                        parameters = args_schema.schema()
                elif isinstance(args_schema, BaseModel):
                    # Pydantic model instance
                    if hasattr(args_schema, "model_dump"):
                        # Pydantic v2
                        parameters = args_schema.model_dump()
                    elif hasattr(args_schema, "dict"):
                        # Pydantic v1
                        parameters = args_schema.dict()
                elif hasattr(args_schema, "schema"):
                    # Has a schema method (might be a class)
                    try:
                        parameters = args_schema.schema()
                    except TypeError:
                        # If schema() requires an instance, try model_json_schema
                        if hasattr(args_schema, "model_json_schema"):
                            parameters = args_schema.model_json_schema()
                        else:
                            parameters = {}
                elif isinstance(args_schema, dict):
                    # Already a dict
                    parameters = args_schema

            openai_tools.append({
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": parameters
                }
            })

    # openai_tools is only non-empty for an "openai" connector model with at
    # least one usable tool. In that case the API is called directly with
    # the tool schemas; the connector call used in the other branch is given
    # no tool schemas.
    if openai_tools:
        import httpx
        from modules.shared.configuration import APP_CONFIG

        api_key = APP_CONFIG.get('Connector_AiOpenai_API_SECRET')
        if not api_key:
            raise ValueError("OpenAI API key not configured")

        payload = {
            "model": selected_model.name,
            "messages": ai_messages,
            "tools": openai_tools,
            "tool_choice": "auto",  # Let model decide when to use tools
            "temperature": selected_model.temperature,
            "max_tokens": selected_model.maxTokens
        }

        async with httpx.AsyncClient(timeout=600.0) as client:
            response_obj = await client.post(
                selected_model.apiUrl,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json=payload
            )

        if response_obj.status_code != 200:
            error_msg = f"OpenAI API error: {response_obj.status_code} - {response_obj.text}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        response_json = response_obj.json()
        message = response_json["choices"][0]["message"]

        # Extract content and tool calls
        content = message.get("content", "")
        tool_calls_raw = message.get("tool_calls")

        # Convert OpenAI tool_calls format to LangChain format:
        # [{"id": "...", "name": "...", "args": {...}}]
        tool_calls = None
        if tool_calls_raw:
            tool_calls = []
            for tc in tool_calls_raw:
                func_data = tc.get("function", {})
                func_name = func_data.get("name")
                func_args_raw = func_data.get("arguments", "{}")

                # Arguments arrive as a JSON string; parse it to a dict
                if isinstance(func_args_raw, str):
                    try:
                        func_args = json.loads(func_args_raw)
                    except json.JSONDecodeError:
                        logger.warning(
                            "Tool call %s has arguments that are not valid JSON; using {}",
                            func_name,
                        )
                        func_args = {}
                else:
                    func_args = func_args_raw

                # The LangChain format above expects a dict of arguments
                if not isinstance(func_args, dict):
                    logger.warning(
                        "Tool call %s has arguments that are not a JSON object; using {}",
                        func_name,
                    )
                    func_args = {}

                tool_calls.append({
                    "id": tc.get("id"),
                    "name": func_name,
                    "args": func_args
                })

        # Create response object
        response = AiModelResponse(
            content=content or "",
            success=True,
            modelId=selected_model.name,
            metadata={
                "response_id": response_json.get("id", ""),
                "tool_calls": tool_calls
            }
        )
    else:
        # No tools or not OpenAI - use connector normally
        if not selected_model.functionCall:
            raise ValueError(f"Model {selected_model.displayName} has no functionCall defined")

        model_call = AiModelCall(
            messages=ai_messages,
            model=selected_model,
            options=AiCallOptions(
                operationType=self.operation_type,
                processingMode=self.processing_mode,
                temperature=selected_model.temperature
            )
        )
        response = await selected_model.functionCall(model_call)

        if not response.success:
            raise ValueError(f"AI model call failed: {response.error or 'Unknown error'}")

        # Extract tool calls from response metadata if present
        # (the key name may vary by connector)
        tool_calls = None
        if response.metadata:
            tool_calls = response.metadata.get("tool_calls") or response.metadata.get("function_calls")

    # Convert response to LangChain format with tool calls
    ai_message = self._convert_ai_response_to_langchain(response, tool_calls=tool_calls)

    # Create generation and result
    return ChatResult(generations=[ChatGeneration(message=ai_message)])
|
|
|
|
def bind_tools(self, tools: List[Any], **kwargs: Any) -> "AICenterChatModel":
    """
    Return a copy of this model with the given tools attached.

    The copy is built from this instance's user, operation type and
    processing mode, and shares its cached model selection.

    Args:
        tools: List of LangChain tools
        **kwargs: Additional arguments (not used by this implementation)

    Returns:
        New AICenterChatModel instance with tools bound
    """
    clone = AICenterChatModel(
        user=self.user,
        operation_type=self.operation_type,
        processing_mode=self.processing_mode
    )

    # Carry over the cached model selection and remember the tools.
    # object.__setattr__ bypasses Pydantic validation for these attributes.
    carried_over = (
        ("_selected_model", self._selected_model),
        ("_bound_tools", tools),
    )
    for attr_name, attr_value in carried_over:
        object.__setattr__(clone, attr_name, attr_value)

    return clone
|
|
|
|
def invoke(
    self,
    input: List[BaseMessage],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> BaseMessage:
    """
    Synchronous invoke method.

    Runs the async _agenerate method to completion with asyncio.run() and
    returns the message of the first generation.

    Args:
        input: List of LangChain messages
        config: Optional runnable config (not used by this implementation)
        **kwargs: Additional arguments, forwarded to _agenerate

    Returns:
        AIMessage response

    Raises:
        RuntimeError: If called while an event loop is already running in
            this thread (use ainvoke there instead)
    """
    # get_running_loop() raises RuntimeError when no loop is running. The
    # guard error is raised in the `else` branch so that it cannot be caught
    # by this very `except` clause.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: safe to drive the coroutine here
        pass
    else:
        raise RuntimeError("Cannot use synchronous invoke in async context. Use ainvoke instead.")

    # asyncio.run() creates a fresh event loop and closes it afterwards, so
    # no loop is left open once the call returns.
    result = asyncio.run(self._agenerate(input, **kwargs))
    return result.generations[0].message
|
|
|
|
async def ainvoke(
    self,
    input: List[BaseMessage],
    config: Optional[RunnableConfig] = None,
    **kwargs: Any,
) -> BaseMessage:
    """
    Async invoke method.

    Awaits _agenerate and returns the message of its first generation.

    Args:
        input: List of LangChain messages
        config: Optional runnable config (not used by this implementation)
        **kwargs: Additional arguments, forwarded to _agenerate

    Returns:
        AIMessage response
    """
    chat_result = await self._agenerate(input, **kwargs)
    first_generation = chat_result.generations[0]
    return first_generation.message
|