# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""Chatbot domain logic."""

import asyncio
import json
import logging
from dataclasses import dataclass
from typing import Annotated, AsyncIterator, Any, List, Optional, TYPE_CHECKING

from pydantic import BaseModel

from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
    trim_messages,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph

from modules.features.chatbot.bridges.ai import AICenterChatModel
from modules.features.chatbot.bridges.memory import DatabaseCheckpointer
from modules.features.chatbot.bridges.tools import (
    create_sql_query_tool,
    create_tavily_search_tool,
    create_send_streaming_message_tool,
)
from modules.features.chatbot.streaming.helpers import ChatStreamingHelper
from modules.features.chatbot.streaming.events import get_event_manager
from modules.datamodels.datamodelUam import User

if TYPE_CHECKING:
    from modules.features.chatbot.config import ChatbotConfig

logger = logging.getLogger(__name__)


class ChatState(BaseModel):
    """Represents the state of a chat session."""

    messages: Annotated[List[BaseMessage], add_messages]
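
# Note on the annotation above: `add_messages` is LangGraph's reducer for
# message lists. It merges a node's returned messages into the existing list
# (appending new ones, replacing entries that share an id) rather than
# overwriting the field. Illustrative sketch of the partial update a node
# returns (the merge itself is applied by LangGraph, not called directly):
#
#     state = ChatState(messages=[HumanMessage(content="Hi")])
#     update = {"messages": [AIMessage(content="Hello!")]}  # a node's return
#     # After the update is applied, state.messages holds both messages.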


@dataclass
class Chatbot:
    """Represents a chatbot."""

    model: AICenterChatModel
    memory: DatabaseCheckpointer
    app: Optional[CompiledStateGraph] = None
    system_prompt: str = "You are a helpful assistant."
    workflow_id: str = "default"
    config: Optional["ChatbotConfig"] = None

    @classmethod
    async def create(
        cls,
        model: AICenterChatModel,
        memory: DatabaseCheckpointer,
        system_prompt: str,
        workflow_id: str = "default",
        config: Optional["ChatbotConfig"] = None,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use (AICenterChatModel).
            memory: The chat memory to use (DatabaseCheckpointer).
            system_prompt: The system prompt to initialize the chatbot.
            workflow_id: The workflow ID (maps to thread_id).
            config: Optional chatbot configuration for dynamic tool enablement.

        Returns:
            A configured Chatbot instance.
        """
        instance = Chatbot(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            workflow_id=workflow_id,
            config=config,
        )
        configured_tools = await instance._configure_tools()
        instance.app = instance._build_app(memory, configured_tools)
        return instance
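
    # Illustrative usage (a sketch, not production wiring; assumes `model` and
    # `memory` have already been constructed by the feature's own setup code):
    #
    #     bot = await Chatbot.create(
    #         model=model,
    #         memory=memory,
    #         system_prompt="You are a helpful assistant.",
    #         workflow_id="workflow-123",
    #     )
    #     history = await bot.chat("Hello", chat_id="thread-1")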

    async def _configure_tools(self) -> List[Any]:
        """Configure tools for the chatbot based on config.

        Returns:
            List of configured tools based on config settings.
        """
        tools = []

        # Get tool enablement from config (use defaults if no config)
        sql_enabled = True
        tavily_enabled = False
        streaming_enabled = True
        connector_type = "preprocessor"

        if self.config:
            sql_enabled = self.config.tools.is_sql_enabled()
            tavily_enabled = self.config.tools.is_tavily_enabled()
            streaming_enabled = self.config.tools.is_streaming_enabled()
            connector_type = self.config.database.connector

        logger.info(
            f"Chatbot tools config - SQL: {sql_enabled}, Tavily: {tavily_enabled}, "
            f"Streaming: {streaming_enabled}, Connector: {connector_type}"
        )

        # SQL query tool (if enabled)
        if sql_enabled:
            sql_tool = create_sql_query_tool(connector_type=connector_type)
            tools.append(sql_tool)
            logger.debug(f"Added SQL query tool with connector: {connector_type}")

        # Tavily search tool (if enabled)
        if tavily_enabled:
            tavily_tool = create_tavily_search_tool()
            tools.append(tavily_tool)
            logger.debug("Added Tavily search tool")

        # Streaming status tool (if enabled)
        if streaming_enabled:
            event_manager = get_event_manager()
            send_streaming_message = create_send_streaming_message_tool(event_manager)
            tools.append(send_streaming_message)
            logger.debug("Added streaming status tool")

        logger.info(f"Configured {len(tools)} tools for chatbot workflow {self.workflow_id}")
        return tools

    def _build_app(
        self, memory: DatabaseCheckpointer, tools: List[Any]
    ) -> CompiledStateGraph:
        """Builds the chatbot application workflow using LangGraph.

        Args:
            memory: The chat memory to use.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: List[BaseMessage]) -> List[BaseMessage]:
            """Selects a window of messages that fit within the context window size.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fit within the context window size.
            """

            def approx_counter(items: List[BaseMessage]) -> int:
                """Approximate token counter for messages.

                Counts characters of message content as a cheap proxy for
                tokens, which overestimates the token count and therefore
                trims conservatively.

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                return sum(len(getattr(m, "content", "") or "") for m in items)

            # Use the model's context length if available, otherwise default
            selected_model = getattr(self.model, "_selected_model", None)
            max_tokens = getattr(selected_model, "contextLength", 128000)

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=int(max_tokens * 0.8),  # Use 80% of context window
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )
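
        # Behavior sketch for `select_window` (illustrative, not normative):
        # with a 128000-token context the budget is 102400. `strategy="last"`
        # keeps the newest messages that fit under the budget,
        # `include_system=True` preserves a leading SystemMessage, and
        # `start_on="human"` avoids opening the window on an orphaned
        # AI tool-call message.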

        async def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after processing.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)

            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window

            # Call the LLM with tools (use ainvoke for async)
            response = await llm_with_tools.ainvoke(window)

            # Return the new state
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message contains
            tool calls) or to end the workflow (if no tool calls are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            # Get the last message
            last_message = state.messages[-1]

            # Check if the last message contains tool calls
            # If so, continue to the tools node; otherwise, end the workflow
            return "tools" if getattr(last_message, "tool_calls", None) else END

        async def tools_with_retry(state: ChatState) -> dict:
            """Tools node with parallel execution and retry logic.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after tool execution.
            """
            # Get tool calls from the last message
            last_message = state.messages[-1]
            tool_calls = getattr(last_message, "tool_calls", [])

            if not tool_calls:
                return {"messages": []}

            # Create a lookup for tools by name
            tools_by_name = {t.name: t for t in tools}

            async def execute_single_tool(tool_call: dict) -> ToolMessage:
                """Execute a single tool call."""
                tool_name = tool_call.get("name") or tool_call.get("function", {}).get("name")
                tool_id = tool_call.get("id", f"call_{tool_name}")
                args = tool_call.get("args") or tool_call.get("function", {}).get("arguments", {})

                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {"input": args}

                tool = tools_by_name.get(tool_name)
                if not tool:
                    return ToolMessage(
                        content=f"Error: Tool '{tool_name}' not found",
                        tool_call_id=tool_id,
                        name=tool_name,
                    )

                try:
                    # Execute tool asynchronously
                    coroutine = getattr(tool, "coroutine", None)
                    if asyncio.iscoroutinefunction(coroutine):
                        result = await coroutine(**args)
                    elif hasattr(tool, "ainvoke"):
                        result = await tool.ainvoke(args)
                    else:
                        result = tool.invoke(args)

                    return ToolMessage(
                        content=str(result),
                        tool_call_id=tool_id,
                        name=tool_name,
                    )
                except Exception as e:
                    logger.error(f"Tool {tool_name} failed: {e}")
                    return ToolMessage(
                        content=f"Error executing {tool_name}: {str(e)}",
                        tool_call_id=tool_id,
                        name=tool_name,
                    )

            # Execute ALL tool calls in parallel
            logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
            tool_messages = await asyncio.gather(
                *[execute_single_tool(tc) for tc in tool_calls],
                return_exceptions=True,
            )

            # Convert exceptions to error messages
            result_messages = []
            for i, msg in enumerate(tool_messages):
                if isinstance(msg, Exception):
                    tool_call = tool_calls[i]
                    tool_name = tool_call.get("name", "unknown")
                    tool_id = tool_call.get("id", f"call_{i}")
                    result_messages.append(ToolMessage(
                        content=f"Error: {str(msg)}",
                        tool_call_id=tool_id,
                        name=tool_name,
                    ))
                else:
                    result_messages.append(msg)

            result = {"messages": result_messages}

            # Check if we got no results and should retry. The keywords are
            # intentionally German to match the German-language tool output
            # ("keine Artikel gefunden" = "no articles found",
            # "keine Ergebnisse" = "no results").
            no_results_keywords = [
                "returned 0 rows",
                "no data",
                "keine artikel gefunden",
                "keine ergebnisse",
            ]

            # Check tool results for no data
            for msg in result.get("messages", []):
                content = getattr(msg, "content", "")
                if isinstance(content, str):
                    content_lower = content.lower()
                    if any(keyword in content_lower for keyword in no_results_keywords):
                        # Check if we haven't retried yet (avoid infinite loops)
                        retry_count = sum(
                            1 for m in state.messages
                            if "retry" in str(getattr(m, "content", "")).lower()
                        )
                        if retry_count < 2:  # Allow max 2 retries
                            logger.info("No results found in tool output, adding retry instruction")
                            # German retry prompt (the assistant serves German
                            # users); it instructs the model to retry with a
                            # combined supplier + product-type filter, product
                            # synonyms, a COUNT query first, and LIKE matching.
                            retry_message = HumanMessage(
                                content="WICHTIG: Die vorherige Suche hat keine Ergebnisse gefunden. "
                                "Bitte versuche eine alternative Suchstrategie:\n"
                                "1. Wenn die Frage im Format 'X von Y' war (z.B. 'Lampen von Eaton'), "
                                "verwende IMMER eine Kombination aus Lieferanten-Filter (WHERE a.\"Lieferant\" LIKE '%Y%') "
                                "UND Produkttyp-Filter (WHERE a.\"Artikelbezeichnung\" LIKE '%X%' OR ...)\n"
                                "2. Verwende mehrere Synonyme für den Produkttyp (z.B. bei 'Lampen': Lampe, LED, Beleuchtung, Licht, Leuchte, Strahler)\n"
                                "3. Führe zuerst eine COUNT-Abfrage durch, dann die Detail-Abfrage mit Lagerbeständen\n"
                                "4. Verwende LIKE '%Lieferant%' für den Lieferanten-Filter, um auch Varianten zu finden"
                            )
                            result["messages"].append(retry_message)
                            break

            return result
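
        # The two tool-call shapes handled above (illustrative values only):
        #
        #   LangChain-style: {"name": "sql_query", "args": {"query": "..."},
        #                     "id": "call_abc"}
        #   OpenAI-style:    {"function": {"name": "sql_query",
        #                     "arguments": '{"query": "..."}'}, "id": "call_abc"}
        #
        # `execute_single_tool` normalizes both, parsing JSON string arguments
        # when needed.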

        # Compose the workflow
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tools_with_retry)
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=memory)
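
    # Resulting graph topology (sketch):
    #
    #     START -> agent --(tool calls present)--> tools -> agent -> ...
    #                    \--(no tool calls)------> END
    #
    # `should_continue` supplies the conditional edge; `tools` always loops
    # back to `agent` so the model can read tool results before answering.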

    async def chat(self, message: str, chat_id: str = "default") -> List[BaseMessage]:
        """Processes a chat message through the LLM and tools, returning the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}

        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )

        # Extract and return the messages from the result
        return result["messages"]
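
    # The returned history is the accumulated graph state for the thread,
    # e.g. (illustrative): [HumanMessage, AIMessage(tool_calls=...),
    # ToolMessage, AIMessage]. Note the system prompt is only prepended to the
    # LLM call window, never persisted; callers usually read the final
    # AIMessage for the answer.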

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config so
                memory and tools are scoped per thread.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream human-readable progress via the special send_streaming_message tool
                # Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py)
                if etype == "on_tool_start":
                    # Log all tool starts to debug
                    logger.debug(f"Tool start event: name='{ename}', event='{etype}'")
                    if ename == "send_streaming_message":
                        tool_in = edata.get("input") or {}
                        msg = tool_in.get("message")
                        logger.info(f"send_streaming_message tool called with input: {tool_in}")
                        if isinstance(msg, str) and msg.strip():
                            logger.info(f"Status update sent: {msg.strip()}")
                            yield {"type": "status", "label": msg.strip()}
                    continue

                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")

                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=output_obj
                    )

                    # Normalize for the frontend (only user/assistant with text content)
                    chat_history_payload: List[dict] = []
                    for m in final_msgs:
                        if isinstance(m, BaseMessage):
                            d = ChatStreamingHelper.message_to_dict(msg=m)
                        elif isinstance(m, dict):
                            d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                        else:
                            continue
                        if d.get("role") in ("user", "assistant") and d.get("content"):
                            chat_history_payload.append(d)

                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    return

        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.error(f"Exception in stream_events: {exc}", exc_info=True)
            # User-facing message is German ("Error while processing: ...")
            yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}
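
# Consuming the stream (illustrative sketch; `bot` is a Chatbot instance and
# the loop would normally live in an SSE/WebSocket handler):
#
#     async for ev in bot.stream_events(message="Hello", chat_id="t-1"):
#         if ev["type"] == "status":
#             ...  # push the progress label to the client
#         elif ev["type"] == "final":
#             ...  # render ev["response"]["chat_history"]
#         elif ev["type"] == "error":
#             ...  # surface ev["message"]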