345 lines
12 KiB
Python
345 lines
12 KiB
Python
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""
LangGraph-based chatbot implementation.

Uses LangGraph workflow with AI Center integration and connector tools.
"""
|
|
|
|
import logging
from dataclasses import dataclass
from typing import Annotated, Any, AsyncIterator, List, Optional

from pydantic import BaseModel

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from modules.features.chatbot.aiCenterAdapter import AICenterChatModel
from modules.features.chatbot.langgraphTools import (
    create_sql_tool,
    create_tavily_tools,
    send_streaming_message,
)
from modules.shared.configuration import APP_CONFIG
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatState(BaseModel):
|
|
"""Represents the state of a chat session."""
|
|
|
|
messages: Annotated[List[BaseMessage], add_messages]
|
|
|
|
|
|
@dataclass
|
|
class LangGraphChatbot:
|
|
"""LangGraph-based chatbot with AI Center integration."""
|
|
|
|
model: AICenterChatModel
|
|
memory: Any
|
|
app: Optional[CompiledStateGraph] = None
|
|
system_prompt: str = "You are a helpful assistant."
|
|
|
|
@classmethod
|
|
async def create(
|
|
cls,
|
|
services,
|
|
system_prompt: str,
|
|
connector_instance,
|
|
enable_web_research: bool = True,
|
|
tavily_api_key: Optional[str] = None,
|
|
context_window_size: int = 8000,
|
|
) -> "LangGraphChatbot":
|
|
"""
|
|
Factory method to create and configure a LangGraphChatbot instance.
|
|
|
|
Args:
|
|
services: Services instance with AI access
|
|
system_prompt: The system prompt to initialize the chatbot
|
|
connector_instance: Database connector instance (PreprocessorConnector)
|
|
enable_web_research: Whether to enable web research tools
|
|
tavily_api_key: Tavily API key for web research (if None, uses APP_CONFIG)
|
|
context_window_size: Maximum context window size in tokens
|
|
|
|
Returns:
|
|
A configured LangGraphChatbot instance
|
|
"""
|
|
# Get Tavily API key from config if not provided
|
|
if tavily_api_key is None:
|
|
tavily_api_key = APP_CONFIG.get("Connector_AiTavily_API_SECRET")
|
|
|
|
# Create AI Center chat model adapter
|
|
model = AICenterChatModel(
|
|
services=services,
|
|
system_prompt=system_prompt,
|
|
temperature=0.2
|
|
)
|
|
|
|
# Create memory/checkpointer
|
|
memory = MemorySaver()
|
|
|
|
instance = LangGraphChatbot(
|
|
model=model,
|
|
memory=memory,
|
|
system_prompt=system_prompt,
|
|
)
|
|
|
|
# Configure tools
|
|
configured_tools = await instance._configure_tools(
|
|
connector_instance,
|
|
enable_web_research,
|
|
tavily_api_key
|
|
)
|
|
|
|
# Build LangGraph app
|
|
instance.app = instance._build_app(memory, configured_tools, context_window_size)
|
|
|
|
return instance
|
|
|
|
async def _configure_tools(
|
|
self,
|
|
connector_instance,
|
|
enable_web_research: bool,
|
|
tavily_api_key: Optional[str]
|
|
) -> List:
|
|
"""
|
|
Configure tools for the chatbot.
|
|
|
|
Args:
|
|
connector_instance: Database connector instance
|
|
enable_web_research: Whether web research is enabled
|
|
tavily_api_key: Tavily API key
|
|
|
|
Returns:
|
|
List of configured tools
|
|
"""
|
|
tools = []
|
|
|
|
# SQL tool using connector
|
|
sql_tool = create_sql_tool(connector_instance)
|
|
tools.append(sql_tool)
|
|
|
|
# Streaming message tool
|
|
tools.append(send_streaming_message)
|
|
|
|
# Tavily tools (if enabled)
|
|
if enable_web_research:
|
|
tavily_tools = create_tavily_tools(tavily_api_key, enable_web_research)
|
|
tools.extend(tavily_tools)
|
|
|
|
logger.info(f"Configured {len(tools)} tools for LangGraph chatbot")
|
|
return tools
|
|
|
|
def _build_app(
|
|
self,
|
|
memory: Any,
|
|
tools: List,
|
|
context_window_size: int
|
|
) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
|
|
"""
|
|
Builds the chatbot application workflow using LangGraph.
|
|
|
|
Args:
|
|
memory: The chat memory/checkpointer to use
|
|
tools: The list of tools the chatbot can use
|
|
context_window_size: Maximum context window size
|
|
|
|
Returns:
|
|
A compiled state graph representing the chatbot application
|
|
"""
|
|
# Bind tools to model
|
|
llm_with_tools = self.model.bind_tools(tools=tools)
|
|
|
|
def select_window(msgs: List[BaseMessage]) -> List[BaseMessage]:
|
|
"""Selects a window of messages that fit within the context window size.
|
|
|
|
Args:
|
|
msgs: The list of messages to select from.
|
|
|
|
Returns:
|
|
A list of messages that fit within the context window size.
|
|
"""
|
|
def approx_counter(items: List[BaseMessage]) -> int:
|
|
"""Approximate token counter for messages.
|
|
|
|
Args:
|
|
items: List of messages to count tokens for.
|
|
|
|
Returns:
|
|
Approximate number of tokens in the messages.
|
|
"""
|
|
return sum(len(getattr(m, "content", "") or "") for m in items)
|
|
|
|
return trim_messages(
|
|
msgs,
|
|
strategy="last",
|
|
token_counter=approx_counter,
|
|
max_tokens=context_window_size,
|
|
start_on="human",
|
|
end_on=("human", "tool"),
|
|
include_system=True,
|
|
)
|
|
|
|
def agent_node(state: ChatState) -> dict:
|
|
"""Agent node for the chatbot workflow.
|
|
|
|
Args:
|
|
state: The current chat state.
|
|
|
|
Returns:
|
|
The updated chat state after processing.
|
|
"""
|
|
# Select the message window to fit in context (trim if needed)
|
|
window = select_window(state.messages)
|
|
|
|
# Ensure the system prompt is present at the start
|
|
if not window or not isinstance(window[0], SystemMessage):
|
|
window = [SystemMessage(content=self.system_prompt)] + window
|
|
|
|
# Call the LLM with tools
|
|
response = llm_with_tools.invoke(window)
|
|
|
|
# Return the new state
|
|
return {"messages": [response]}
|
|
|
|
def should_continue(state: ChatState) -> str:
|
|
"""Determines whether to continue the workflow or end it.
|
|
|
|
This conditional edge is called after the agent node to decide
|
|
whether to continue to the tools node (if the last message contains
|
|
tool calls) or to end the workflow (if no tool calls are present).
|
|
|
|
Args:
|
|
state: The current chat state.
|
|
|
|
Returns:
|
|
The next node to transition to ("tools" or END).
|
|
"""
|
|
# Get the last message
|
|
last_message = state.messages[-1]
|
|
|
|
# Check if the last message contains tool calls
|
|
# If so, continue to the tools node; otherwise, end the workflow
|
|
return "tools" if getattr(last_message, "tool_calls", None) else END
|
|
|
|
# Compose the workflow
|
|
workflow = StateGraph(ChatState)
|
|
workflow.add_node("agent", agent_node)
|
|
workflow.add_node("tools", ToolNode(tools=tools))
|
|
workflow.add_edge(START, "agent")
|
|
workflow.add_conditional_edges("agent", should_continue)
|
|
workflow.add_edge("tools", "agent")
|
|
|
|
return workflow.compile(checkpointer=memory)
|
|
|
|
async def chat(self, message: str, chat_id: str = "default") -> List[BaseMessage]:
|
|
"""
|
|
Process a chat message by calling the LLM and tools and returns the chat history.
|
|
|
|
Args:
|
|
message: The user message to process
|
|
chat_id: The chat thread ID
|
|
|
|
Returns:
|
|
The list of messages in the chat history
|
|
"""
|
|
if not self.app:
|
|
raise RuntimeError("Chatbot app not initialized. Call create() first.")
|
|
|
|
# Set the right thread ID for memory
|
|
config = {"configurable": {"thread_id": chat_id}}
|
|
|
|
# Single-turn chat (non-streaming)
|
|
result = await self.app.ainvoke(
|
|
{"messages": [HumanMessage(content=message)]}, config=config
|
|
)
|
|
|
|
# Extract and return the messages from the result
|
|
return result["messages"]
|
|
|
|
async def stream_events(
|
|
self, *, message: str, chat_id: str = "default"
|
|
) -> AsyncIterator[dict]:
|
|
"""
|
|
Stream UI-focused events using astream_events v2.
|
|
|
|
Args:
|
|
message: The user message to process
|
|
chat_id: Logical thread identifier; forwarded in the runnable config so
|
|
memory and tools are scoped per thread
|
|
|
|
Yields:
|
|
dict: One of:
|
|
- ``{"type": "status", "label": str}`` for short progress updates.
|
|
- ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
|
|
where ``chat_history`` only includes ``user``/``assistant`` roles.
|
|
- ``{"type": "error", "message": str}`` if an exception occurs.
|
|
"""
|
|
if not self.app:
|
|
raise RuntimeError("Chatbot app not initialized. Call create() first.")
|
|
|
|
# Thread-aware config for LangGraph/LangChain
|
|
config = {"configurable": {"thread_id": chat_id}}
|
|
|
|
def _is_root(ev: dict) -> bool:
|
|
"""Return True if the event is from the root run (v2: empty parent_ids)."""
|
|
return not ev.get("parent_ids")
|
|
|
|
try:
|
|
async for event in self.app.astream_events(
|
|
{"messages": [HumanMessage(content=message)]},
|
|
config=config,
|
|
version="v2",
|
|
):
|
|
etype = event.get("event")
|
|
ename = event.get("name") or ""
|
|
edata = event.get("data") or {}
|
|
|
|
# Stream human-readable progress via the special send_streaming_message tool
|
|
if etype == "on_tool_start" and ename == "send_streaming_message":
|
|
tool_in = edata.get("input") or {}
|
|
msg = tool_in.get("message")
|
|
if isinstance(msg, str) and msg.strip():
|
|
yield {"type": "status", "label": msg.strip()}
|
|
continue
|
|
|
|
# Emit the final payload when the root run finishes
|
|
if etype == "on_chain_end" and _is_root(event):
|
|
output_obj = edata.get("output")
|
|
|
|
# Extract message list from the graph's final output
|
|
final_msgs = output_obj.get("messages", []) if isinstance(output_obj, dict) else []
|
|
|
|
# Normalize for the frontend (only user/assistant with text content)
|
|
chat_history_payload: List[dict] = []
|
|
for m in final_msgs:
|
|
if isinstance(m, BaseMessage):
|
|
role = "user" if isinstance(m, HumanMessage) else "assistant" if isinstance(m, BaseMessage) else None
|
|
content = getattr(m, "content", "")
|
|
if role and content:
|
|
chat_history_payload.append({
|
|
"role": role,
|
|
"content": content
|
|
})
|
|
|
|
yield {
|
|
"type": "final",
|
|
"response": {
|
|
"thread": chat_id,
|
|
"chat_history": chat_history_payload,
|
|
},
|
|
}
|
|
return
|
|
|
|
except Exception as exc:
|
|
# Emit a single error envelope and end the stream
|
|
logger.error(f"Exception in stream_events: {exc}", exc_info=True)
|
|
yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}
|