# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
AI Center to LangChain bridge.

Implements LangChain BaseChatModel interface using AI center models.
"""

import logging
import asyncio
import inspect
import json
import time
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
from datetime import datetime

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    AIMessage,
    ToolMessage,
    convert_to_openai_messages,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig

from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.datamodels.datamodelAi import (
    AiModel,
    AiModelCall,
    AiModelResponse,
    AiCallResponse,
    AiCallOptions,
    OperationTypeEnum,
    ProcessingModeEnum,
)
from modules.datamodels.datamodelUam import User

logger = logging.getLogger(__name__)

# Workflow-level store for allowed_providers and RBAC context (survives LangGraph/bind_tools
# execution context where instance attributes may be lost when model is wrapped or serialized)
_workflow_allowed_providers: Dict[str, List[str]] = {}
_workflow_rbac_context: Dict[str, tuple] = {}  # workflow_id -> (mandateId, featureInstanceId)

# Connector types that are called directly through an OpenAI-compatible
# /v1/chat/completions endpoint (OpenAI itself, and Private-LLM's Ollama endpoint).
_OPENAI_COMPATIBLE_CONNECTORS = ("openai", "privatellm")
# Endpoint used for OpenAI models that have no apiUrl configured.
_DEFAULT_OPENAI_CHAT_URL = "https://api.openai.com/v1/chat/completions"
# Timeout (seconds) for direct HTTP calls to the chat-completions endpoints.
_HTTP_TIMEOUT_SECONDS = 600.0


def clear_workflow_allowed_providers(workflow_id: str) -> None:
    """Remove a workflow from the workflow-level stores to avoid memory growth.

    Intended to be called when the workflow's stream completes. Drops both the
    allowed-providers entry and the RBAC-context entry for ``workflow_id``;
    unknown ids are ignored.
    """
    _workflow_allowed_providers.pop(workflow_id, None)
    _workflow_rbac_context.pop(workflow_id, None)


def _message_text(content: Any) -> str:
    """Flatten LangChain message content to plain text.

    ``content`` is either a string or a list of content blocks (strings or dicts
    with a ``"text"`` key). Non-text blocks (e.g. images) are skipped so that the
    result can safely be joined into a prompt.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for block in content:
            if isinstance(block, str):
                texts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                texts.append(block["text"])
        return "\n".join(texts)
    return "" if content is None else str(content)


class AICenterChatModel(BaseChatModel):
    """
    LangChain-compatible chat model that uses AI center models.

    Bridges AI center model selection and calling to LangChain's BaseChatModel interface.
    """

    def __init__(
        self,
        user: User,
        operation_type: OperationTypeEnum = OperationTypeEnum.DATA_ANALYSE,
        processing_mode: ProcessingModeEnum = ProcessingModeEnum.DETAILED,
        billing_callback: Optional[Callable[[AiCallResponse], None]] = None,
        workflow_id: Optional[str] = None,
        allowed_providers: Optional[List[str]] = None,
        prefer_fast_model: bool = False,
        mandate_id: Optional[str] = None,
        feature_instance_id: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the AI center chat model bridge.

        Args:
            user: Current user for RBAC and model selection
            operation_type: Operation type for model selection
            processing_mode: Processing mode for model selection
            billing_callback: Optional callback invoked once per _agenerate call with an
                AiCallResponse for billing
            workflow_id: Optional workflow/conversation ID for billing context; also the key
                of the workflow-level provider/RBAC stores
            allowed_providers: Optional list of allowed provider connector types (empty/None = all)
            prefer_fast_model: When True, strongly prefer faster models (e.g. gpt-4o-mini for planner)
            mandate_id: Optional mandate ID passed to RBAC model filtering
            feature_instance_id: Optional feature instance ID passed to RBAC model filtering
            **kwargs: Additional arguments passed to BaseChatModel
        """
        super().__init__(**kwargs)
        # Use object.__setattr__ to bypass Pydantic validation for custom attributes
        object.__setattr__(self, "user", user)
        object.__setattr__(self, "operation_type", operation_type)
        object.__setattr__(self, "processing_mode", processing_mode)
        object.__setattr__(self, "_selected_model", None)
        object.__setattr__(self, "_billing_callback", billing_callback)
        object.__setattr__(self, "_workflow_id", workflow_id)
        object.__setattr__(self, "_allowed_providers", allowed_providers or [])
        object.__setattr__(self, "_prefer_fast_model", prefer_fast_model)
        object.__setattr__(self, "_mandate_id", mandate_id)
        object.__setattr__(self, "_feature_instance_id", feature_instance_id)
        # Store in workflow-level registry so it survives when instance attrs are lost (e.g. bind_tools)
        if workflow_id and allowed_providers:
            _workflow_allowed_providers[workflow_id] = list(allowed_providers)
        if workflow_id and (mandate_id is not None or feature_instance_id is not None):
            _workflow_rbac_context[workflow_id] = (mandate_id, feature_instance_id)

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "aicenter"

    def _select_model(self, messages: List[BaseMessage]) -> AiModel:
        """
        Select the best AI center model for the given messages.

        The first selection is cached on the instance and returned on later calls.

        Args:
            messages: List of LangChain messages

        Returns:
            Selected AI model

        Raises:
            ValueError: If the model selector finds no suitable model.
        """
        # Return cached model if already selected (significant performance improvement)
        if self._selected_model is not None:
            return self._selected_model

        # System/human messages form the prompt; AI/tool messages form the context.
        prompt_parts: List[str] = []
        context_parts: List[str] = []
        for msg in messages:
            if isinstance(msg, (SystemMessage, HumanMessage)):
                prompt_parts.append(_message_text(msg.content))
            elif isinstance(msg, AIMessage):
                context_parts.append(_message_text(msg.content))
            elif isinstance(msg, ToolMessage):
                context_parts.append(f"Tool {msg.name}: {_message_text(msg.content)}")
        prompt = "\n".join(prompt_parts)
        context = "\n".join(context_parts) if context_parts else ""

        # Get available models with RBAC filtering
        # Use cached/singleton interfaces for better performance
        from modules.interfaces.interfaceDbApp import getRootInterface

        workflow_id = getattr(self, "_workflow_id", None)
        rbac_instance = getRootInterface().rbac
        mandate_id = getattr(self, "_mandate_id", None)
        feature_instance_id = getattr(self, "_feature_instance_id", None)
        # Fall back to the workflow-level RBAC context when instance attrs are missing.
        if workflow_id and mandate_id is None and feature_instance_id is None:
            ctx = _workflow_rbac_context.get(workflow_id)
            if ctx:
                mandate_id, feature_instance_id = ctx
        available_models = modelRegistry.getAvailableModels(
            currentUser=self.user,
            rbacInstance=rbac_instance,
            mandateId=mandate_id,
            featureInstanceId=feature_instance_id,
        )

        # Allowed providers: instance attr or workflow store (lost in LangGraph/bind_tools context)
        allowed = (
            (_workflow_allowed_providers.get(workflow_id) if workflow_id else None)
            or getattr(self, "_allowed_providers", None)
            or []
        )
        if allowed:
            logger.info(f"AICenterChatModel _select_model: applying allowedProviders={allowed}")
            filtered = [m for m in available_models if m.connectorType in allowed]
            if filtered:
                available_models = filtered
            else:
                logger.warning(f"No models match allowedProviders {allowed}, using all RBAC-permitted models")

        options = AiCallOptions(
            operationType=self.operation_type,
            processingMode=self.processing_mode,
            allowedProviders=allowed if allowed else None,
            preferFastModel=getattr(self, "_prefer_fast_model", False),
        )

        selected_model = modelSelector.selectModel(
            prompt=prompt,
            context=context,
            options=options,
            availableModels=available_models,
        )
        if not selected_model:
            raise ValueError(f"No suitable model found for operation type {self.operation_type.value}")

        logger.info(f"Selected AI center model: {selected_model.displayName} ({selected_model.name})")
        object.__setattr__(self, "_selected_model", selected_model)
        return selected_model

    def _convert_messages_to_ai_format(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]:
        """
        Convert LangChain messages to AI center format (OpenAI-style).

        Args:
            messages: List of LangChain messages

        Returns:
            List of messages in OpenAI format
        """
        # Use LangChain's built-in conversion
        return convert_to_openai_messages(messages)

    def _convert_ai_response_to_langchain(
        self,
        response: AiModelResponse,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
    ) -> AIMessage:
        """
        Convert AI center response to LangChain AIMessage.

        Args:
            response: AI center response
            tool_calls: Optional tool calls from the response
                (format: [{"id": "...", "name": "...", "args": {...}}])

        Returns:
            LangChain AIMessage with tool_calls if present
        """
        kwargs = {}
        if tool_calls:
            kwargs["tool_calls"] = tool_calls
        return AIMessage(content=response.content or "", **kwargs)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Synchronous generate method required by BaseChatModel.

        Drives _agenerate to completion with asyncio.run. Only valid when no event
        loop is running in the current thread.

        Args:
            messages: List of LangChain messages
            stop: Optional stop sequences
            run_manager: Optional callback manager
            **kwargs: Additional arguments

        Returns:
            ChatResult with generations

        Raises:
            RuntimeError: If called while an event loop is running in this thread.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: safe to run the coroutine synchronously.
            return asyncio.run(self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs))
        raise RuntimeError(
            "AICenterChatModel._generate() called from async context. "
            "Use _agenerate() instead."
        )

    def _resolve_chat_completions_endpoint(self) -> tuple:
        """Return ``(api_url, headers, model_name)`` for the selected model's chat-completions endpoint.

        OpenAI: the model's ``apiUrl`` (default: the public endpoint) with a bearer token.
        Private-LLM: the Ollama OpenAI-compatible ``/v1/chat/completions`` endpoint next to
        ``/api/analyze``, addressed by the underlying model name (``version``, e.g.
        qwen2.5:7b) rather than the AI center model name.

        Raises:
            ValueError: If the OpenAI API key is not configured.
        """
        from modules.shared.configuration import APP_CONFIG

        model = self._selected_model
        if model.connectorType == "openai":
            api_url = getattr(model, "apiUrl", None) or _DEFAULT_OPENAI_CHAT_URL
            api_key = APP_CONFIG.get("Connector_AiOpenai_API_SECRET")
            if not api_key:
                raise ValueError("OpenAI API key not configured")
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            return api_url, headers, model.name

        base_url = (getattr(model, "apiUrl", None) or "").replace("/api/analyze", "")
        api_url = f"{base_url.rstrip('/')}/v1/chat/completions"
        api_key = APP_CONFIG.get("Connector_AiPrivateLlm_API_SECRET")
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["X-API-Key"] = api_key
        return api_url, headers, getattr(model, "version", None) or model.name

    async def _call_connector(self, model_call: AiModelCall) -> AiModelResponse:
        """Call the selected model through its AI center connector (``functionCall``).

        Raises:
            ValueError: If the model has no functionCall.
        """
        if not self._selected_model.functionCall:
            raise ValueError(f"Model {self._selected_model.displayName} has no functionCall defined")
        return await self._selected_model.functionCall(model_call)

    @staticmethod
    async def _emit_token(run_manager: Optional[Any], token: str) -> None:
        """Forward a streamed token to LangChain callbacks, awaiting async callback managers."""
        callback = getattr(run_manager, "on_llm_new_token", None) if run_manager else None
        if callback is None:
            return
        result = callback(token)
        # AsyncCallbackManagerForLLMRun.on_llm_new_token is a coroutine function.
        if inspect.isawaitable(result):
            await result

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[str]:
        """Extract the content token from one server-sent-events line.

        Returns None for the ``[DONE]`` terminator, otherwise the delta content
        ("" for non-data lines, malformed JSON, or chunks without content).
        """
        line = line.strip()
        if not line.startswith("data: "):
            return ""
        data_str = line[6:].strip()
        if data_str == "[DONE]":
            return None
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return ""  # tolerate partial / keep-alive frames
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices") or []
        if not choices:
            return ""
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    async def _read_sse_content(self, resp: Any, run_manager: Optional[Any]) -> str:
        """Read an OpenAI-style SSE response body, emitting each token; return the joined text."""
        content_parts: List[str] = []

        async def consume(line: str) -> bool:
            # Returns True once the [DONE] terminator has been seen.
            token = self._parse_sse_line(line)
            if token is None:
                return True
            if token:
                await self._emit_token(run_manager, token)
                content_parts.append(token)
            return False

        buffer = ""
        async for chunk in resp.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, _, buffer = buffer.partition("\n")
                if await consume(line):
                    return "".join(content_parts)
        # The last event may arrive without a trailing newline.
        if buffer.strip():
            await consume(buffer)
        return "".join(content_parts)

    async def _call_openai_streaming(
        self,
        ai_messages: List[dict],
        run_manager: Optional[Any],
        model_call: "AiModelCall",
        input_bytes: int,
        start_time: float,
    ) -> "AiModelResponse":
        """Call OpenAI/Ollama with stream=True, emit tokens via run_manager, return full response.

        If a Private-LLM service has no /v1/chat/completions endpoint (HTTP 404), falls
        back to the connector call with ``model_call`` (no token streaming).

        Billing is not done here; ``_agenerate`` bills every call exactly once.
        ``input_bytes`` and ``start_time`` are unused and kept for signature compatibility.

        Raises:
            ValueError: On any other non-200 status, or if the OpenAI key is missing.
        """
        import httpx

        api_url, headers, model_name = self._resolve_chat_completions_endpoint()
        payload = {
            "model": model_name,
            "messages": ai_messages,
            "temperature": self._selected_model.temperature,
            "max_tokens": self._selected_model.maxTokens,
            "stream": True,
        }

        content = ""
        use_connector_fallback = False
        async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_SECONDS) as client:
            async with client.stream("POST", api_url, headers=headers, json=payload) as resp:
                if resp.status_code == 404 and self._selected_model.connectorType == "privatellm":
                    use_connector_fallback = True
                elif resp.status_code != 200:
                    raise ValueError(f"OpenAI stream error: {resp.status_code} - {await resp.aread()}")
                else:
                    content = await self._read_sse_content(resp, run_manager)

        if use_connector_fallback:
            logger.warning(
                "Private-LLM /v1/chat/completions not found (404). Falling back to /api/analyze "
                "without token streaming."
            )
            return await self._call_connector(model_call)

        return AiModelResponse(content=content, success=True, modelId=self._selected_model.name, metadata={})

    @staticmethod
    def _inject_tool_instructions(ai_messages: List[Dict[str, Any]], tools: List[Any]) -> None:
        """Describe the bound tools in the system message (created if absent), in place.

        Some models need explicit tool descriptions in the prompt to actually call tools.
        """
        tool_descriptions = []
        for tool in tools:
            if not (hasattr(tool, "name") and hasattr(tool, "description")):
                continue
            params_info = ""
            args_schema = getattr(tool, "args_schema", None)
            if args_schema and hasattr(args_schema, "model_json_schema"):
                try:
                    schema = args_schema.model_json_schema()
                    if "properties" in schema:
                        params = list(schema["properties"].keys())
                        params_info = f" (Parameter: {', '.join(params)})"
                except Exception:
                    pass  # parameter hints are best-effort; the tool is still listed
            tool_descriptions.append(f"- {tool.name}: {tool.description}{params_info}")

        if not tool_descriptions:
            return

        tools_text = "\n".join(tool_descriptions)
        tools_note = f"\n\n⚠️⚠️⚠️ KRITISCH - TOOL-NUTZUNG ⚠️⚠️⚠️\n\nVERFÜGBARE TOOLS:\n{tools_text}\n\nABSOLUT VERBINDLICH:\n- Du MUSST diese Tools verwenden, um Anfragen zu bearbeiten!\n- Für Status-Updates MUSST du IMMER das Tool 'send_streaming_message' verwenden!\n- VERBOTEN: Normale Text-Nachrichten für Status-Updates!\n- Du MUSST Tools aufrufen, nicht nur darüber sprechen!\n\nBeispiel FALSCH: \"Ich werde die Datenbank durchsuchen...\"\nBeispiel RICHTIG: Rufe das Tool 'send_streaming_message' mit \"Durchsuche Datenbank...\" auf!"

        for msg in ai_messages:
            if msg.get("role") == "system":
                # Append to the first existing system message
                msg["content"] += tools_note
                return
        # No system message: add one at the beginning
        ai_messages.insert(0, {"role": "system", "content": tools_note.strip()})

    @staticmethod
    def _args_schema_to_parameters(args_schema: Any) -> Dict[str, Any]:
        """Convert a tool's ``args_schema`` (Pydantic class/instance or dict) to a JSON schema dict."""
        from pydantic import BaseModel

        if not args_schema:
            return {}
        if isinstance(args_schema, dict):
            return args_schema
        if isinstance(args_schema, type) and issubclass(args_schema, BaseModel):
            if hasattr(args_schema, "model_json_schema"):  # Pydantic v2
                return args_schema.model_json_schema()
            if hasattr(args_schema, "schema"):  # Pydantic v1
                return args_schema.schema()
            return {}
        if isinstance(args_schema, BaseModel):
            # An instance: describe its model class's schema, not the instance data.
            return AICenterChatModel._args_schema_to_parameters(type(args_schema))
        if hasattr(args_schema, "schema"):
            # e.g. Pydantic v1 classes when pydantic v2 is installed
            try:
                return args_schema.schema()
            except TypeError:
                if hasattr(args_schema, "model_json_schema"):
                    return args_schema.model_json_schema()
                return {}
        return {}

    @classmethod
    def _build_openai_tools(cls, tools: List[Any]) -> List[Dict[str, Any]]:
        """Convert LangChain tools to the OpenAI function-calling format (also used by Ollama)."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": cls._args_schema_to_parameters(getattr(tool, "args_schema", None)),
                },
            }
            for tool in tools
            if hasattr(tool, "name") and hasattr(tool, "description")
        ]

    @staticmethod
    def _parse_tool_calls(tool_calls_raw: Optional[List[Dict[str, Any]]]) -> Optional[List[Dict[str, Any]]]:
        """Convert OpenAI ``tool_calls`` to LangChain's ``[{"id", "name", "args"}]`` format.

        Arguments that are not valid JSON objects become ``{}``. Returns None when
        there are no tool calls.
        """
        if not tool_calls_raw:
            return None
        parsed = []
        for tc in tool_calls_raw:
            func_data = tc.get("function") or {}
            raw_args = func_data.get("arguments", "{}")
            try:
                func_args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
            except ValueError:
                func_args = {}
            if not isinstance(func_args, dict):
                func_args = {}
            parsed.append({
                "id": tc.get("id"),
                "name": func_data.get("name"),
                "args": func_args,
            })
        return parsed

    async def _call_with_tools(
        self,
        ai_messages: List[Dict[str, Any]],
        openai_tools: List[Dict[str, Any]],
        model_call: AiModelCall,
    ) -> AiModelResponse:
        """Call the OpenAI-compatible endpoint with tool definitions (non-streaming).

        A Private-LLM 404 falls back to the connector call, which cannot do tool calling.

        Raises:
            ValueError: On any other non-200 status, or if the OpenAI key is missing.
        """
        import httpx

        api_url, headers, model_name = self._resolve_chat_completions_endpoint()
        payload = {
            "model": model_name,
            "messages": ai_messages,
            "tools": openai_tools,
            "tool_choice": "auto",
            "temperature": self._selected_model.temperature,
            "max_tokens": self._selected_model.maxTokens,
        }
        async with httpx.AsyncClient(timeout=_HTTP_TIMEOUT_SECONDS) as client:
            response_obj = await client.post(api_url, headers=headers, json=payload)

        if response_obj.status_code == 404 and self._selected_model.connectorType == "privatellm":
            logger.warning(
                "Private-LLM /v1/chat/completions not found (404). Falling back to /api/analyze. "
                "Tool calling will not work until the service exposes an OpenAI-compatible endpoint."
            )
            return await self._call_connector(model_call)
        if response_obj.status_code != 200:
            error_msg = f"AI API error ({self._selected_model.connectorType}): {response_obj.status_code} - {response_obj.text}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        response_json = response_obj.json()
        message = response_json["choices"][0]["message"]
        return AiModelResponse(
            content=message.get("content") or "",
            success=True,
            modelId=self._selected_model.name,
            metadata={
                "response_id": response_json.get("id", ""),
                "tool_calls": self._parse_tool_calls(message.get("tool_calls")),
            },
        )

    def _report_billing(self, content: str, input_bytes: int, start_time: float) -> None:
        """Price the call and pass an AiCallResponse to the billing callback, if one is set.

        Pricing and callback errors are logged, never raised.
        """
        model = self._selected_model
        output_bytes = len(content.encode("utf-8"))
        processing_time = time.time() - start_time
        price_chf = 0.0
        if getattr(model, "calculatepriceCHF", None):
            try:
                price_chf = model.calculatepriceCHF(processing_time, input_bytes, output_bytes)
            except Exception as e:
                logger.warning(f"Billing: price calculation failed: {e}")

        billing_callback = getattr(self, "_billing_callback", None)
        if not billing_callback:
            return
        try:
            billing_callback(AiCallResponse(
                content=content,
                modelName=model.name,
                provider=getattr(model, "connectorType", "unknown") or "unknown",
                priceCHF=price_chf,
                processingTime=processing_time,
                bytesSent=input_bytes,
                bytesReceived=output_bytes,
                errorCount=0,
            ))
        except Exception as e:
            logger.error(f"Billing callback error: {e}")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Async generate method required by BaseChatModel.

        Routing:
        - tools bound and OpenAI/Private-LLM model: direct tool-calling request
        - no tools and OpenAI/Private-LLM model: token-streaming request
        - otherwise: the model's AI center connector (``functionCall``)

        Every call is billed exactly once.

        Args:
            messages: List of LangChain messages
            stop: Optional stop sequences (accepted for interface compatibility; not forwarded)
            run_manager: Optional callback manager (receives streamed tokens)
            **kwargs: Additional arguments (unused; bound tools come from bind_tools)

        Returns:
            ChatResult with generations

        Raises:
            ValueError: If no model can be selected, the provider call fails, or the
                response is unsuccessful.
        """
        if not self._selected_model:
            self._select_model(messages)

        tools = getattr(self, "_bound_tools", None)
        ai_messages = self._convert_messages_to_ai_format(messages)

        # Bytes sent for billing: only plain-string contents, measured before tool notes are added.
        input_bytes = sum(
            len(m["content"].encode("utf-8"))
            for m in ai_messages
            if isinstance(m.get("content"), str)
        )
        start_time = time.time()

        if tools:
            self._inject_tool_instructions(ai_messages, tools)

        connector = self._selected_model.connectorType
        openai_tools = (
            self._build_openai_tools(tools)
            if tools and connector in _OPENAI_COMPATIBLE_CONNECTORS
            else None
        )

        model_call = AiModelCall(
            messages=ai_messages,
            model=self._selected_model,
            options=AiCallOptions(
                operationType=self.operation_type,
                processingMode=self.processing_mode,
                temperature=self._selected_model.temperature,
            ),
        )

        if openai_tools:
            response = await self._call_with_tools(ai_messages, openai_tools, model_call)
        elif not tools and connector in _OPENAI_COMPATIBLE_CONNECTORS:
            # ChatGPT-like token streaming
            response = await self._call_openai_streaming(
                ai_messages, run_manager, model_call, input_bytes, start_time
            )
        else:
            response = await self._call_connector(model_call)

        if not response.success:
            raise ValueError(f"AI model call failed: {response.error or 'Unknown error'}")

        self._report_billing(response.content or "", input_bytes, start_time)

        # Tool calls travel in metadata; the key name varies by connector.
        tool_calls = None
        if response.metadata:
            tool_calls = response.metadata.get("tool_calls") or response.metadata.get("function_calls")

        ai_message = self._convert_ai_response_to_langchain(response, tool_calls=tool_calls)
        return ChatResult(generations=[ChatGeneration(message=ai_message)])

    def bind_tools(self, tools: List[Any], **kwargs: Any) -> "AICenterChatModel":
        """
        Bind tools to the model (required for LangGraph tool calling).

        Returns a new instance carrying the same user, operation settings, billing
        callback, workflow, provider restrictions, fast-model preference and RBAC
        context, plus the already-selected model and the tools.

        Args:
            tools: List of LangChain tools
            **kwargs: Additional arguments (ignored)

        Returns:
            New instance with tools bound
        """
        bound_model = AICenterChatModel(
            user=self.user,
            operation_type=self.operation_type,
            processing_mode=self.processing_mode,
            billing_callback=getattr(self, "_billing_callback", None),
            workflow_id=getattr(self, "_workflow_id", None),
            allowed_providers=getattr(self, "_allowed_providers", None),
            prefer_fast_model=getattr(self, "_prefer_fast_model", False),
            mandate_id=getattr(self, "_mandate_id", None),
            feature_instance_id=getattr(self, "_feature_instance_id", None),
        )
        object.__setattr__(bound_model, "_selected_model", self._selected_model)
        # Read by _agenerate to describe and pass the tools to the model
        object.__setattr__(bound_model, "_bound_tools", tools)
        return bound_model

    def invoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Synchronous invoke method (required by BaseChatModel).

        Runs _agenerate with asyncio.run; only valid when no event loop is running
        in the current thread. ``config`` is accepted but not used.

        Args:
            input: List of LangChain messages
            config: Optional runnable config
            **kwargs: Additional arguments forwarded to _agenerate

        Returns:
            AIMessage response

        Raises:
            RuntimeError: If called while an event loop is running in this thread.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            result = asyncio.run(self._agenerate(input, **kwargs))
            return result.generations[0].message
        raise RuntimeError("Cannot use synchronous invoke in async context. Use ainvoke instead.")

    async def ainvoke(
        self,
        input: List[BaseMessage],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        """
        Async invoke method (required by BaseChatModel).

        Args:
            input: List of LangChain messages
            config: Optional runnable config (not used)
            **kwargs: Additional arguments forwarded to _agenerate

        Returns:
            AIMessage response
        """
        result = await self._agenerate(input, **kwargs)
        return result.generations[0].message