# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
LangGraph-based chatbot implementation.
Uses LangGraph workflow with AI Center integration and connector tools.
"""

import logging
from dataclasses import dataclass
from typing import Annotated, AsyncIterator, Any, Optional, List

from pydantic import BaseModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

from modules.features.chatbot.aiCenterAdapter import AICenterChatModel
from modules.features.chatbot.langgraphTools import (
    send_streaming_message,
    create_sql_tool,
    create_tavily_tools,
)
from modules.shared.configuration import APP_CONFIG

logger = logging.getLogger(__name__)


class ChatState(BaseModel):
    """Represents the state of a chat session."""

    messages: Annotated[List[BaseMessage], add_messages]
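
# A minimal sketch of what the `add_messages` reducer above does when a graph
# node returns {"messages": [...]}: new entries are appended to the running
# list, and a message with a matching ID replaces the existing one. LangGraph
# invokes the reducer itself, so nothing in this module calls this helper; the
# sample messages are placeholders.
def _sketch_add_messages() -> List[BaseMessage]:
    existing: List[BaseMessage] = [HumanMessage(content="Hello")]
    update: List[BaseMessage] = [SystemMessage(content="You are a helpful assistant.")]
    # Returns both messages in order; neither is overwritten because their
    # auto-assigned IDs differ.
    return add_messages(existing, update)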

@dataclass
class LangGraphChatbot:
    """LangGraph-based chatbot with AI Center integration."""

    model: AICenterChatModel
    memory: Any
    app: Optional[CompiledStateGraph] = None
    system_prompt: str = "You are a helpful assistant."

    @classmethod
    async def create(
        cls,
        services,
        system_prompt: str,
        connector_instance,
        enable_web_research: bool = True,
        tavily_api_key: Optional[str] = None,
        context_window_size: int = 8000,
    ) -> "LangGraphChatbot":
        """
        Factory method to create and configure a LangGraphChatbot instance.

        Args:
            services: Services instance with AI access
            system_prompt: The system prompt to initialize the chatbot
            connector_instance: Database connector instance (PreprocessorConnector)
            enable_web_research: Whether to enable web research tools
            tavily_api_key: Tavily API key for web research (if None, uses APP_CONFIG)
            context_window_size: Maximum context window size in tokens

        Returns:
            A configured LangGraphChatbot instance
        """
        # Get Tavily API key from config if not provided
        if tavily_api_key is None:
            tavily_api_key = APP_CONFIG.get("Connector_AiTavily_API_SECRET")

        # Create AI Center chat model adapter
        model = AICenterChatModel(
            services=services, system_prompt=system_prompt, temperature=0.2
        )

        # Create memory/checkpointer
        memory = MemorySaver()

        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
        )

        # Configure tools
        configured_tools = await instance._configure_tools(
            connector_instance, enable_web_research, tavily_api_key
        )

        # Build LangGraph app
        instance.app = instance._build_app(
            memory, configured_tools, context_window_size
        )

        return instance

    async def _configure_tools(
        self,
        connector_instance,
        enable_web_research: bool,
        tavily_api_key: Optional[str],
    ) -> List:
        """
        Configure tools for the chatbot.

        Args:
            connector_instance: Database connector instance
            enable_web_research: Whether web research is enabled
            tavily_api_key: Tavily API key

        Returns:
            List of configured tools
        """
        tools = []

        # SQL tool using connector
        sql_tool = create_sql_tool(connector_instance)
        tools.append(sql_tool)

        # Streaming message tool
        tools.append(send_streaming_message)

        # Tavily tools (if enabled)
        if enable_web_research:
            tavily_tools = create_tavily_tools(tavily_api_key, enable_web_research)
            tools.extend(tavily_tools)

        logger.info(f"Configured {len(tools)} tools for LangGraph chatbot")
        return tools

    def _build_app(
        self, memory: Any, tools: List, context_window_size: int
    ) -> CompiledStateGraph:
        """
        Builds the chatbot application workflow using LangGraph.

        Args:
            memory: The chat memory/checkpointer to use
            tools: The list of tools the chatbot can use
            context_window_size: Maximum context window size

        Returns:
            A compiled state graph representing the chatbot application
        """
        # Bind tools to model
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: List[BaseMessage]) -> List[BaseMessage]:
            """Selects a window of messages that fit within the context window size.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fit within the context window size.
            """

            def approx_counter(items: List[BaseMessage]) -> int:
                """Approximate token counter for messages.

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                return sum(len(getattr(m, "content", "") or "") for m in items)

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=context_window_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after processing.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)

            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window

            # Call the LLM with tools
            response = llm_with_tools.invoke(window)

            # Return the new state
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message
            contains tool calls) or to end the workflow (if no tool calls
            are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            # Get the last message
            last_message = state.messages[-1]

            # Check if the last message contains tool calls
            # If so, continue to the tools node; otherwise, end the workflow
            return "tools" if getattr(last_message, "tool_calls", None) else END

        # Compose the workflow
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")

        return workflow.compile(checkpointer=memory)
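
    # Resulting topology: tool-calling turns loop back into the agent until
    # the model answers without requesting a tool.
    #
    #     START -> agent --(tool_calls)--> tools -> agent
    #                 \--(no tool_calls)--> END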

    async def chat(self, message: str, chat_id: str = "default") -> List[BaseMessage]:
        """
        Process a chat message by calling the LLM and tools and return the chat history.

        Args:
            message: The user message to process
            chat_id: The chat thread ID

        Returns:
            The list of messages in the chat history
        """
        if not self.app:
            raise RuntimeError("Chatbot app not initialized. Call create() first.")

        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}

        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )

        # Extract and return the messages from the result
        return result["messages"]

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """
        Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process
            chat_id: Logical thread identifier; forwarded in the runnable
                config so memory and tools are scoped per thread

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        if not self.app:
            raise RuntimeError("Chatbot app not initialized. Call create() first.")

        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream human-readable progress via the special
                # send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input") or {}
                    msg = tool_in.get("message")
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue

                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")

                    # Extract message list from the graph's final output
                    final_msgs = (
                        output_obj.get("messages", [])
                        if isinstance(output_obj, dict)
                        else []
                    )

                    # Normalize for the frontend: only user/assistant turns
                    # with text content (system and tool messages are dropped)
                    chat_history_payload: List[dict] = []
                    for m in final_msgs:
                        if isinstance(m, (HumanMessage, AIMessage)):
                            role = "user" if isinstance(m, HumanMessage) else "assistant"
                            content = getattr(m, "content", "")
                            if content:
                                chat_history_payload.append(
                                    {"role": role, "content": content}
                                )

                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    return
        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.error(f"Exception in stream_events: {exc}", exc_info=True)
            yield {"type": "error", "message": f"Error while processing: {exc}"}
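
# Illustrative usage sketch; not called anywhere in this module. `services`
# and `connector_instance` are placeholders for the application objects
# documented in LangGraphChatbot.create().
async def example_usage(services: Any, connector_instance: Any) -> None:
    """Drive one non-streaming turn and one streaming turn on a shared thread."""
    bot = await LangGraphChatbot.create(
        services=services,
        system_prompt="You are a helpful assistant.",
        connector_instance=connector_instance,
    )

    # Non-streaming: returns the full message history for the thread.
    history = await bot.chat("Hello", chat_id="demo")
    print(history[-1].content)

    # Streaming: zero or more status updates, then a single final payload.
    async for event in bot.stream_events(message="Summarize that.", chat_id="demo"):
        if event["type"] == "status":
            print("status:", event["label"])
        elif event["type"] == "final":
            print(event["response"]["chat_history"])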