# Copyright (c) 2025 Patrick Motsch # All rights reserved. """Chatbot domain logic.""" import logging from dataclasses import dataclass, field from typing import Annotated, AsyncIterator, Any, List, Optional, TYPE_CHECKING from pydantic import BaseModel from langchain_core.messages import ( BaseMessage, HumanMessage, SystemMessage, ToolMessage, trim_messages, ) from langgraph.graph.message import add_messages from langgraph.graph import StateGraph, START, END from langgraph.graph.state import CompiledStateGraph from langgraph.prebuilt import ToolNode from modules.features.chatbot.bridges.ai import AICenterChatModel from modules.features.chatbot.bridges.memory import DatabaseCheckpointer from modules.features.chatbot.bridges.tools import ( create_sql_query_tool, create_tavily_search_tool, create_send_streaming_message_tool, ) from modules.features.chatbot.streaming.helpers import ChatStreamingHelper from modules.features.chatbot.streaming.events import get_event_manager from modules.datamodels.datamodelUam import User if TYPE_CHECKING: from modules.features.chatbot.config import ChatbotConfig logger = logging.getLogger(__name__) class ChatState(BaseModel): """Represents the state of a chat session.""" messages: Annotated[List[BaseMessage], add_messages] @dataclass class Chatbot: """Represents a chatbot.""" model: AICenterChatModel memory: DatabaseCheckpointer app: CompiledStateGraph = None system_prompt: str = "You are a helpful assistant." workflow_id: str = "default" config: Optional["ChatbotConfig"] = None @classmethod async def create( cls, model: AICenterChatModel, memory: DatabaseCheckpointer, system_prompt: str, workflow_id: str = "default", config: Optional["ChatbotConfig"] = None, ) -> "Chatbot": """Factory method to create and configure a Chatbot instance. Args: model: The chat model to use (AICenterChatModel). memory: The chat memory to use (DatabaseCheckpointer). system_prompt: The system prompt to initialize the chatbot. workflow_id: The workflow ID (maps to thread_id). config: Optional chatbot configuration for dynamic tool enablement. Returns: A configured Chatbot instance. """ instance = Chatbot( model=model, memory=memory, system_prompt=system_prompt, workflow_id=workflow_id, config=config, ) configured_tools = await instance._configure_tools() instance.app = instance._build_app(memory, configured_tools) return instance async def _configure_tools(self) -> List[Any]: """Configure tools for the chatbot based on config. Returns: List of configured tools based on config settings. """ tools = [] # Get tool enablement from config (use defaults if no config) sql_enabled = True tavily_enabled = False streaming_enabled = True connector_type = "preprocessor" if self.config: sql_enabled = self.config.tools.is_sql_enabled() tavily_enabled = self.config.tools.is_tavily_enabled() streaming_enabled = self.config.tools.is_streaming_enabled() connector_type = self.config.database.connector logger.info(f"Chatbot tools config - SQL: {sql_enabled}, Tavily: {tavily_enabled}, " f"Streaming: {streaming_enabled}, Connector: {connector_type}") # SQL query tool (if enabled) if sql_enabled: sql_tool = create_sql_query_tool(connector_type=connector_type) tools.append(sql_tool) logger.debug(f"Added SQL query tool with connector: {connector_type}") # Tavily search tool (if enabled) if tavily_enabled: tavily_tool = create_tavily_search_tool() tools.append(tavily_tool) logger.debug("Added Tavily search tool") # Streaming status tool (if enabled) if streaming_enabled: event_manager = get_event_manager() send_streaming_message = create_send_streaming_message_tool(event_manager) tools.append(send_streaming_message) logger.debug("Added streaming status tool") logger.info(f"Configured {len(tools)} tools for chatbot workflow {self.workflow_id}") return tools def _build_app( self, memory: DatabaseCheckpointer, tools: List[Any] ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]: """Builds the chatbot application workflow using LangGraph. Args: memory: The chat memory to use. tools: The list of tools the chatbot can use. Returns: A compiled state graph representing the chatbot application. """ llm_with_tools = self.model.bind_tools(tools=tools) def select_window(msgs: List[BaseMessage]) -> List[BaseMessage]: """Selects a window of messages that fit within the context window size. Args: msgs: The list of messages to select from. Returns: A list of messages that fit within the context window size. """ def approx_counter(items: List[BaseMessage]) -> int: """Approximate token counter for messages. Args: items: List of messages to count tokens for. Returns: Approximate number of tokens in the messages. """ return sum(len(getattr(m, "content", "") or "") for m in items) # Use model's context length if available, otherwise default max_tokens = getattr(self.model._selected_model, "contextLength", 128000) if hasattr(self.model, "_selected_model") and self.model._selected_model else 128000 return trim_messages( msgs, strategy="last", token_counter=approx_counter, max_tokens=int(max_tokens * 0.8), # Use 80% of context window start_on="human", end_on=("human", "tool"), include_system=True, ) async def agent_node(state: ChatState) -> dict: """Agent node for the chatbot workflow. Args: state: The current chat state. Returns: The updated chat state after processing. """ # Select the message window to fit in context (trim if needed) window = select_window(state.messages) # Ensure the system prompt is present at the start if not window or not isinstance(window[0], SystemMessage): window = [SystemMessage(content=self.system_prompt)] + window # Call the LLM with tools (use ainvoke for async) response = await llm_with_tools.ainvoke(window) # Return the new state return {"messages": [response]} def should_continue(state: ChatState) -> str: """Determines whether to continue the workflow or end it. This conditional edge is called after the agent node to decide whether to continue to the tools node (if the last message contains tool calls) or to end the workflow (if no tool calls are present). Args: state: The current chat state. Returns: The next node to transition to ("tools" or END). """ # Get the last message last_message = state.messages[-1] # Check if the last message contains tool calls # If so, continue to the tools node; otherwise, end the workflow return "tools" if getattr(last_message, "tool_calls", None) else END async def tools_with_retry(state: ChatState) -> dict: """Tools node with parallel execution and retry logic. Args: state: The current chat state. Returns: The updated chat state after tool execution. """ import asyncio # Get tool calls from the last message last_message = state.messages[-1] tool_calls = getattr(last_message, "tool_calls", []) if not tool_calls: return {"messages": []} # Create a lookup for tools by name tools_by_name = {t.name: t for t in tools} async def execute_single_tool(tool_call): """Execute a single tool call.""" tool_name = tool_call.get("name") or tool_call.get("function", {}).get("name") tool_id = tool_call.get("id", f"call_{tool_name}") args = tool_call.get("args") or tool_call.get("function", {}).get("arguments", {}) if isinstance(args, str): import json try: args = json.loads(args) except: args = {"input": args} tool = tools_by_name.get(tool_name) if not tool: return ToolMessage( content=f"Error: Tool '{tool_name}' not found", tool_call_id=tool_id, name=tool_name ) try: # Execute tool asynchronously if asyncio.iscoroutinefunction(tool.coroutine): result = await tool.coroutine(**args) elif hasattr(tool, 'ainvoke'): result = await tool.ainvoke(args) else: result = tool.invoke(args) return ToolMessage( content=str(result), tool_call_id=tool_id, name=tool_name ) except Exception as e: logger.error(f"Tool {tool_name} failed: {e}") return ToolMessage( content=f"Error executing {tool_name}: {str(e)}", tool_call_id=tool_id, name=tool_name ) # Execute ALL tool calls in parallel logger.info(f"Executing {len(tool_calls)} tool calls in parallel") tool_messages = await asyncio.gather( *[execute_single_tool(tc) for tc in tool_calls], return_exceptions=True ) # Convert exceptions to error messages result_messages = [] for i, msg in enumerate(tool_messages): if isinstance(msg, Exception): tool_call = tool_calls[i] tool_name = tool_call.get("name", "unknown") tool_id = tool_call.get("id", f"call_{i}") result_messages.append(ToolMessage( content=f"Error: {str(msg)}", tool_call_id=tool_id, name=tool_name )) else: result_messages.append(msg) result = {"messages": result_messages} # Check if we got no results and should retry no_results_keywords = [ "returned 0 rows", "no data", "keine artikel gefunden", "keine ergebnisse" ] # Check tool results for no data for msg in result.get("messages", []): content = getattr(msg, "content", "") if isinstance(content, str): content_lower = content.lower() if any(keyword in content_lower for keyword in no_results_keywords): # Check if we haven't retried yet (avoid infinite loops) retry_count = sum(1 for m in state.messages if "retry" in str(getattr(m, "content", "")).lower()) if retry_count < 2: # Allow max 2 retries logger.info("No results found in tool output, adding retry instruction") retry_message = HumanMessage( content="WICHTIG: Die vorherige Suche hat keine Ergebnisse gefunden. " "Bitte versuche eine alternative Suchstrategie:\n" "1. Wenn die Frage im Format 'X von Y' war (z.B. 'Lampen von Eaton'), " "verwende IMMER eine Kombination aus Lieferanten-Filter (WHERE a.\"Lieferant\" LIKE '%Y%') " "UND Produkttyp-Filter (WHERE a.\"Artikelbezeichnung\" LIKE '%X%' OR ...)\n" "2. Verwende mehrere Synonyme für den Produkttyp (z.B. bei 'Lampen': Lampe, LED, Beleuchtung, Licht, Leuchte, Strahler)\n" "3. Führe zuerst eine COUNT-Abfrage durch, dann die Detail-Abfrage mit Lagerbeständen\n" "4. Verwende LIKE '%Lieferant%' für den Lieferanten-Filter, um auch Varianten zu finden" ) result["messages"].append(retry_message) break return result # Compose the workflow workflow = StateGraph(ChatState) workflow.add_node("agent", agent_node) workflow.add_node("tools", tools_with_retry) workflow.add_edge(START, "agent") workflow.add_conditional_edges("agent", should_continue) workflow.add_edge("tools", "agent") return workflow.compile(checkpointer=memory) async def chat(self, message: str, chat_id: str = "default") -> List[BaseMessage]: """Processes a chat message by calling the LLM and tools and returns the chat history. Args: message: The user message to process. chat_id: The chat thread ID. Returns: The list of messages in the chat history. """ # Set the right thread ID for memory config = {"configurable": {"thread_id": chat_id}} # Single-turn chat (non-streaming) result = await self.app.ainvoke( {"messages": [HumanMessage(content=message)]}, config=config ) # Extract and return the messages from the result return result["messages"] async def stream_events( self, *, message: str, chat_id: str = "default" ) -> AsyncIterator[dict]: """Stream UI-focused events using astream_events v2. Args: message: The user message to process. chat_id: Logical thread identifier; forwarded in the runnable config so memory and tools are scoped per thread. Yields: dict: One of: - ``{"type": "status", "label": str}`` for short progress updates. - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}`` where ``chat_history`` only includes ``user``/``assistant`` roles. - ``{"type": "error", "message": str}`` if an exception occurs. """ # Thread-aware config for LangGraph/LangChain config = {"configurable": {"thread_id": chat_id}} def _is_root(ev: dict) -> bool: """Return True if the event is from the root run (v2: empty parent_ids).""" return not ev.get("parent_ids") try: async for event in self.app.astream_events( {"messages": [HumanMessage(content=message)]}, config=config, version="v2", ): etype = event.get("event") ename = event.get("name") or "" edata = event.get("data") or {} # Stream human-readable progress via the special send_streaming_message tool # Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py) if etype == "on_tool_start": # Log all tool starts to debug logger.debug(f"Tool start event: name='{ename}', event='{etype}'") if ename == "send_streaming_message": tool_in = edata.get("input") or {} msg = tool_in.get("message") logger.info(f"send_streaming_message tool called with input: {tool_in}") if isinstance(msg, str) and msg.strip(): logger.info(f"Status-Update gesendet: {msg.strip()}") yield {"type": "status", "label": msg.strip()} continue # Emit the final payload when the root run finishes if etype == "on_chain_end" and _is_root(event): output_obj = edata.get("output") # Extract message list from the graph's final output final_msgs = ChatStreamingHelper.extract_messages_from_output( output_obj=output_obj ) # Normalize for the frontend (only user/assistant with text content) chat_history_payload: List[dict] = [] for m in final_msgs: if isinstance(m, BaseMessage): d = ChatStreamingHelper.message_to_dict(msg=m) elif isinstance(m, dict): d = ChatStreamingHelper.dict_message_to_dict(obj=m) else: continue if d.get("role") in ("user", "assistant") and d.get("content"): chat_history_payload.append(d) yield { "type": "final", "response": { "thread": chat_id, "chat_history": chat_history_payload, }, } return except Exception as exc: # Emit a single error envelope and end the stream logger.error(f"Exception in stream_events: {exc}", exc_info=True) yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}