833 lines
39 KiB
Python
833 lines
39 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Chatbot domain logic."""
|
|
|
|
import re
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from typing import Annotated, AsyncIterator, Any, List, Optional, TYPE_CHECKING
|
|
from pydantic import BaseModel
|
|
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
trim_messages,
|
|
)
|
|
from langgraph.graph.message import add_messages
|
|
from langgraph.graph import StateGraph, START, END
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
|
|
from modules.features.chatbot.bridges.ai import AICenterChatModel
|
|
from modules.features.chatbot.bridges.memory import DatabaseCheckpointer
|
|
from modules.features.chatbot.bridges.tools import (
|
|
create_sql_query_tool,
|
|
create_tavily_search_tool,
|
|
create_send_streaming_message_tool,
|
|
)
|
|
from modules.features.chatbot.streaming.helpers import ChatStreamingHelper
|
|
from modules.features.chatbot.streaming.events import get_event_manager
|
|
from modules.datamodels.datamodelUam import User
|
|
|
|
if TYPE_CHECKING:
|
|
from modules.features.chatbot.config import ChatbotConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _tool_output_to_markdown_table(raw: str) -> str:
|
|
"""
|
|
Convert sqlite_query tool output to a markdown table for deterministic display.
|
|
Reduces model hallucination by providing a ready-to-copy table.
|
|
Format: "Query returned N rows:\\nColumns: A, B, C\\n1. A: x, B: y, C: z\\n..."
|
|
"""
|
|
if not raw or not raw.strip():
|
|
return raw
|
|
lines = [ln.strip() for ln in raw.strip().split("\n") if ln.strip()]
|
|
if len(lines) < 2:
|
|
return raw
|
|
# Parse header
|
|
row_count_line = lines[0] # "Query returned 20 rows:"
|
|
cols_line = next((ln for ln in lines if ln.lower().startswith("columns:")), None)
|
|
if not cols_line:
|
|
return raw
|
|
headers = [h.strip() for h in cols_line.replace("Columns:", "").split(",")]
|
|
if not headers:
|
|
return raw
|
|
# Parse data rows (1. Col: val, Col: val)
|
|
rows = []
|
|
for ln in lines:
|
|
if re.match(r"^\d+\.\s+", ln):
|
|
rest = re.sub(r"^\d+\.\s+", "", ln)
|
|
row = {}
|
|
for part in rest.split(", "):
|
|
if ": " in part:
|
|
k, v = part.split(": ", 1)
|
|
row[k.strip()] = str(v).strip()
|
|
if row:
|
|
rows.append([row.get(h, "") for h in headers])
|
|
if not rows:
|
|
return raw
|
|
# Build markdown table
|
|
sep = " | "
|
|
header_row = sep.join(headers)
|
|
div_row = sep.join(["---"] * len(headers))
|
|
data_rows = [sep.join(str(c) for c in r) for r in rows]
|
|
table = "\n".join([header_row, div_row] + data_rows)
|
|
suffix = ""
|
|
if "(Showing first" in raw or "of " in raw:
|
|
m = re.search(r"\(Showing first (\d+) of (\d+) rows\)", raw)
|
|
if m:
|
|
suffix = f"\n\nZeige {m.group(1)} von {m.group(2)} Artikeln."
|
|
return f"{row_count_line}\n\n{table}{suffix}"
|
|
|
|
|
|
def _sanitize_llm_response(text: str) -> str:
|
|
"""Strip chat template tokens and trailing junk that some models leak."""
|
|
if not text or not isinstance(text, str):
|
|
return text or ""
|
|
for sentinel in ("<|im_start|>", "<|im_end|>", "<|endoftext|>", "<|user|>", "<|assistant|>"):
|
|
if sentinel in text:
|
|
text = text.split(sentinel)[0]
|
|
return text.strip()
|
|
|
|
|
|
# Natural language markers to split system prompt into context sections
|
|
_SPLIT_MARKERS = {
|
|
"schema_start": "Die Datenbank enthält",
|
|
"schema_start_alt": "Die Datenbank enthält die Tabellen",
|
|
"response_structure_start": "Antwortstruktur ist strikt",
|
|
"response_structure_alt": "Antwortstruktur:",
|
|
"response_structure_fallback": "Antwortstruktur",
|
|
}
|
|
|
|
|
|
def _split_system_prompt(prompt: str) -> dict:
|
|
"""
|
|
Split system prompt by natural language section markers.
|
|
Returns: {intro, schema, response_structure}
|
|
- intro: Role, tools, general instructions (before schema)
|
|
- schema: Database tables, SQL rules, column definitions (for SQL generation)
|
|
- response_structure: Mandatory answer format (Einleitungssatz, Tabelle, etc.)
|
|
"""
|
|
if not prompt or not isinstance(prompt, str):
|
|
return {"intro": "", "schema": "", "response_structure": ""}
|
|
|
|
text = prompt.strip()
|
|
intro_end = len(text)
|
|
schema_start_idx = -1
|
|
schema_end = len(text)
|
|
response_start_idx = -1
|
|
|
|
# Find schema start
|
|
for marker in (_SPLIT_MARKERS["schema_start"], _SPLIT_MARKERS["schema_start_alt"]):
|
|
idx = text.find(marker)
|
|
if idx >= 0:
|
|
schema_start_idx = idx
|
|
intro_end = idx
|
|
break
|
|
|
|
# Find response structure start
|
|
for marker in (
|
|
_SPLIT_MARKERS["response_structure_start"],
|
|
_SPLIT_MARKERS["response_structure_alt"],
|
|
_SPLIT_MARKERS["response_structure_fallback"],
|
|
):
|
|
idx = text.find(marker)
|
|
if idx >= 0:
|
|
response_start_idx = idx
|
|
schema_end = idx if schema_start_idx >= 0 else len(text)
|
|
break
|
|
|
|
intro = text[:intro_end].strip() if intro_end > 0 else ""
|
|
schema = (
|
|
text[schema_start_idx:schema_end].strip()
|
|
if schema_start_idx >= 0 and schema_end > schema_start_idx
|
|
else ""
|
|
)
|
|
response_structure = (
|
|
text[response_start_idx:].strip()
|
|
if response_start_idx >= 0
|
|
else ""
|
|
)
|
|
|
|
# Fallback: if no markers found, use full prompt for intro
|
|
if not intro and not schema and not response_structure:
|
|
intro = text
|
|
elif not response_structure and intro:
|
|
response_structure = intro # Use intro's format hints as fallback
|
|
|
|
return {"intro": intro, "schema": schema, "response_structure": response_structure}
|
|
|
|
|
|
class ChatState(BaseModel):
    """Represents the state of a chat session (the LangGraph graph state)."""

    # Conversation history; `add_messages` merges node outputs into the
    # existing list instead of replacing it.
    messages: Annotated[List[BaseMessage], add_messages]
    # Planner routing decision: "SQL", "TAVILY", "BOTH" or "NONE".
    plan: Optional[str] = None
|
|
|
|
|
|
@dataclass
class Chatbot:
    """Represents a chatbot.

    Wraps a LangGraph workflow (planner -> SQL / Tavily / plain-answer agents)
    around an AI-Center chat model. Instances are built via the async factory
    :meth:`create`, which configures tools and compiles the graph.
    """

    model: AICenterChatModel
    memory: DatabaseCheckpointer
    # Compiled LangGraph app; populated by `create()` after tool configuration.
    app: Optional[CompiledStateGraph] = None
    system_prompt: str = "You are a helpful assistant."
    # Logical workflow identifier (maps to the checkpointer thread_id).
    workflow_id: str = "default"
    # Optional configuration for dynamic tool enablement.
    config: Optional["ChatbotConfig"] = None
|
|
|
|
@classmethod
async def create(
    cls,
    model: AICenterChatModel,
    memory: DatabaseCheckpointer,
    system_prompt: str,
    workflow_id: str = "default",
    config: Optional["ChatbotConfig"] = None,
) -> "Chatbot":
    """Build and wire up a ready-to-use Chatbot instance.

    Args:
        model: The chat model to use (AICenterChatModel).
        memory: The chat memory to use (DatabaseCheckpointer).
        system_prompt: The system prompt to initialize the chatbot.
        workflow_id: The workflow ID (maps to thread_id).
        config: Optional chatbot configuration for dynamic tool enablement.

    Returns:
        A configured Chatbot instance with its graph compiled.
    """
    bot = Chatbot(
        model=model,
        memory=memory,
        system_prompt=system_prompt,
        workflow_id=workflow_id,
        config=config,
    )
    # Tool configuration is async (event manager lookup), so construction
    # happens in two steps: instantiate, then compile the graph.
    configured = await bot._configure_tools()
    bot.app = bot._build_app(memory, configured)
    return bot
|
|
|
|
async def _configure_tools(self) -> List[Any]:
    """Assemble the tool list for this chatbot based on its configuration.

    Returns:
        List of configured tools based on config settings.
    """
    # Defaults used when no config object was provided.
    sql_enabled = True
    tavily_enabled = False
    streaming_enabled = True
    connector_type = "preprocessor"

    if self.config:
        sql_enabled = self.config.tools.is_sql_enabled()
        tavily_enabled = self.config.tools.is_tavily_enabled()
        streaming_enabled = self.config.tools.is_streaming_enabled()
        connector_type = self.config.database.connector

    logger.info(f"Chatbot tools config - SQL: {sql_enabled}, Tavily: {tavily_enabled}, "
                f"Streaming: {streaming_enabled}, Connector: {connector_type}")

    tools: List[Any] = []

    # SQL query tool (if enabled)
    if sql_enabled:
        tools.append(create_sql_query_tool(connector_type=connector_type))
        logger.debug(f"Added SQL query tool with connector: {connector_type}")

    # Tavily search tool (if enabled)
    if tavily_enabled:
        tools.append(create_tavily_search_tool())
        logger.debug("Added Tavily search tool")

    # Streaming status tool (if enabled)
    if streaming_enabled:
        tools.append(create_send_streaming_message_tool(get_event_manager()))
        logger.debug("Added streaming status tool")

    logger.info(f"Configured {len(tools)} tools for chatbot workflow {self.workflow_id}")
    return tools
|
|
|
|
def _build_app(
    self, memory: DatabaseCheckpointer, tools: List[Any]
) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
    """Builds the chatbot application workflow using LangGraph.

    Supports small context windows via planning phase and tiered prompts:
    a cheap planner first routes the request, then a specialized agent node
    (SQL generation, Tavily search, or plain answer) handles it.

    Args:
        memory: The chat memory to use.
        tools: The list of tools the chatbot can use.

    Returns:
        A compiled state graph representing the chatbot application.
    """
    # Build tool subsets per agent type
    tools_by_name = {t.name: t for t in tools}
    sql_tool = tools_by_name.get("sqlite_query")
    tavily_tool = tools_by_name.get("tavily_search")
    streaming_tool = tools_by_name.get("send_streaming_message")
    # NOTE(review): tools_sql is never referenced below - the SQL path uses
    # sql_tool directly via parse_execute_sql; candidate for removal.
    tools_sql = [t for t in [sql_tool, tavily_tool, streaming_tool] if t is not None]
    tools_tavily = [t for t in [tavily_tool, streaming_tool] if t is not None]
    llm_plain = self.model
    # SQL path uses structured prompts + parse/execute (no native tool calling) - fits /api/analyze
    llm_tavily = self.model.bind_tools(tools=tools_tavily) if tools_tavily else self.model

    # Minimal planner prompt (~250 tokens) - fits any 8K+ model
    # Explicit: Lager, Bestand, Artikel, wie viele = SQL (Datenbank)
    PLANNER_SYSTEM = (
        "Du bist ein Assistent. Antworte NUR mit einem Wort: SQL, TAVILY, BOTH oder NONE.\n"
        "SQL = Fragen zu Lager, Bestand, Artikel, Preisen, wie viele, Anzahl (Datenbankabfrage).\n"
        "TAVILY = Internetsuche, Produktinfos außerhalb der DB, Markttrends.\n"
        "BOTH = beides nötig. NONE = nur Begrüßung oder Danksagung, keine Daten nötig.\n"
        "Beispiele: 'wie viele X auf Lager' -> SQL, 'Infos zu Produkt Y' -> TAVILY."
    )

    # Truncation suffix for schema when prompt is cut
    SCHEMA_TRUNCATION_SUFFIX = (
        "\n\n[... Schema gekürzt. Wichtige Tabellen: Artikel, Lagerplatz_Artikel, Einkaufspreis, Lagerplatz. "
        "Artikel-Spalte: a.\"Artikelbezeichnung\". "
        "JOIN: Artikel a, Lagerplatz_Artikel l ON a.I_ID = l.R_ARTIKEL, Lagerplatz lp ON l.R_LAGERPLATZ = lp.I_ID.]"
    )

    # Structured output for /api/analyze (no tool calls): model outputs SQL in code block, we parse and execute
    SQL_PLAN_SUFFIX = (
        "\n\n--- AUSGABEFORMAT (PFLICHT) ---\n"
        "Antworte NUR mit einer SQL SELECT-Abfrage in diesem Format:\n"
        "```sql\nDEINE_SQL_QUERY\n```\n"
        "KRITISCH bei 'wie viele X auf Lager': Liefere ARTIKELZEILEN (Artikelnummer, Artikelbezeichnung, Bestand) "
        "mit LIMIT 20, NICHT nur SELECT COUNT(*). Ohne Detailzeilen kann keine Tabelle angezeigt werden. "
        "Gesamtanzahl optional via Unterabfrage im SELECT."
    )

    # Rough char<->token budget constants used by the prompt truncation below.
    bytes_per_token = 3  # Balanced estimate for mixed content
    reserved_tokens = 3000  # Tools block + conversation overhead

    def _get_context_length() -> int:
        """Get selected model's context length; pre-select if needed."""
        # Falls back to 128K when no model has been selected yet.
        if hasattr(self.model, "_selected_model") and self.model._selected_model:
            return getattr(self.model._selected_model, "contextLength", 128000)
        return 128000

    def _truncate_system_prompt(full_prompt: str, max_chars: int, suffix: str = "") -> str:
        """Truncate system prompt to fit context budget."""
        if len(full_prompt) <= max_chars:
            return full_prompt
        # Reserve room for the suffix so the result stays within max_chars.
        return full_prompt[: max_chars - len(suffix)] + suffix

    # Split system prompt by natural language sections for targeted context
    _prompt_sections = _split_system_prompt(self.system_prompt)
|
|
|
|
def select_window(msgs: List[BaseMessage], max_tokens_override: Optional[int] = None) -> List[BaseMessage]:
    """Selects a window of messages that fit within the context window size."""

    def approx_counter(items: List[BaseMessage]) -> int:
        # Character count as a cheap token proxy.
        # NOTE(review): this compares CHARACTERS against a token budget
        # (max_tokens * 0.8 below); presumably intentional as a conservative
        # bound, but confirm against bytes_per_token used elsewhere.
        return sum(len(getattr(m, "content", "") or "") for m in items)

    max_tokens = max_tokens_override or _get_context_length()
    # Keep the most recent messages, never starting mid-exchange, and always
    # retain a leading system message if present.
    return trim_messages(
        msgs,
        strategy="last",
        token_counter=approx_counter,
        max_tokens=int(max_tokens * 0.8),
        start_on="human",
        end_on=("human", "tool"),
        include_system=True,
    )
|
|
|
|
async def planner_node(state: ChatState) -> dict:
    """Planner: minimal prompt, no tools. Outputs SQL/TAVILY/BOTH/NONE.

    Does NOT add the planner message to the chat - only sets state.plan
    for routing. Only the LAST human message is shown to the planner.
    """
    human_msgs = [m for m in state.messages if isinstance(m, HumanMessage)]
    last_human = human_msgs[-1].content if human_msgs else ""
    window = [
        SystemMessage(content=PLANNER_SYSTEM),
        HumanMessage(content=last_human),
    ]
    # Default route when the model is unavailable or yields no usable keyword.
    plan = "SQL"
    try:
        response = await llm_plain.ainvoke(window)
    except ValueError as exc:
        if "No suitable model found" in str(exc):
            logger.warning(f"Planner model selection failed: {exc}")
            return {"plan": plan}
        raise
    # Accept the first known keyword found anywhere in the reply.
    content = (response.content or "").strip().upper()
    for keyword in ("SQL", "TAVILY", "BOTH", "NONE"):
        if keyword in content:
            plan = keyword
            break
    return {"plan": plan}
|
|
|
|
# Keywords that indicate database/inventory query - override NONE to SQL
_SQL_KEYWORDS = (
    "lager", "bestand", "artikel", "wie viele", "anzahl", "preis",
    "lieferant", "lieferanten", "bestellen", "verfügbar", "inventar"
)

def route_by_plan(state: ChatState) -> str:
    """Route from planner to agent_sql_plan, agent_tavily, or agent_answer."""
    plan = (state.plan or "SQL").upper()
    # Override NONE when user clearly asks for inventory/data (e.g. "wie viele LEDs auf Lager")
    if plan == "NONE" and sql_tool:
        last_user = ""
        for m in reversed(state.messages):
            if isinstance(m, HumanMessage):
                last_user = (m.content or "").lower()
                break
        if any(kw in last_user for kw in _SQL_KEYWORDS):
            logger.info("Planner returned NONE but user asked inventory question - routing to SQL")
            plan = "SQL"
    # Fall back to the plain-answer agent when the planned tool is disabled.
    if plan in ("SQL", "BOTH") and sql_tool:
        return "agent_sql_plan"
    if plan == "TAVILY" and tavily_tool:
        return "agent_tavily"
    return "agent_answer"
|
|
|
|
async def _agent_common(
    state: ChatState,
    system_content: str,
    llm,
    node_name: str,
) -> dict:
    """Shared logic for agent nodes.

    Trims the history to the context window, forces `system_content` as the
    single system message, invokes `llm` and returns the reply as new state.
    """
    msgs = select_window(state.messages)
    if not msgs or not isinstance(msgs[0], SystemMessage):
        window = [SystemMessage(content=system_content)] + msgs
    else:
        # Replace any persisted system messages with the node-specific one.
        window = [SystemMessage(content=system_content)] + [m for m in msgs if not isinstance(m, SystemMessage)]
    try:
        response = await llm.ainvoke(window)
    except ValueError as exc:
        if "No suitable model found" in str(exc):
            # Degrade gracefully with a user-facing apology instead of crashing.
            logger.warning(f"{node_name} model selection failed: {exc}")
            response = AIMessage(
                content=(
                    "Es tut mir leid, derzeit steht kein passendes KI-Modell für diese Anfrage zur Verfügung. "
                    "Bitte versuchen Sie es später erneut oder wenden Sie sich an den Administrator."
                )
            )
        else:
            raise
    return {"messages": [response]}
|
|
|
|
async def agent_sql_plan_node(state: ChatState) -> dict:
    """Generate SQL. Uses schema section + minimal intro. Output: ```sql...``` for parse/execute."""
    ctx_len = _get_context_length()
    # Char budget: 80% of the context minus reserved tokens, scaled to chars,
    # minus the mandatory output-format suffix appended below (min 1000 chars).
    max_system_chars = max(1000, int(ctx_len * 0.8 - reserved_tokens) * bytes_per_token) - len(SQL_PLAN_SUFFIX)
    # Prefer schema section; add short intro if space allows
    schema_part = _prompt_sections["schema"] or _prompt_sections["intro"]
    intro_part = _prompt_sections["intro"][:400] if _prompt_sections["intro"] else ""
    combined = f"{intro_part}\n\n{schema_part}" if intro_part else schema_part
    system_content = _truncate_system_prompt(
        combined, max_system_chars, SCHEMA_TRUNCATION_SUFFIX
    ) + SQL_PLAN_SUFFIX
    return await _agent_common(state, system_content, llm_plain, "agent_sql_plan")
|
|
|
|
def _parse_sql_from_content(content: str) -> Optional[str]:
|
|
"""Extract SQL from ```sql...``` or ```...``` code block. Only allows SELECT."""
|
|
if not content:
|
|
return None
|
|
match = re.search(r"```(?:sql)?\s*([\s\S]*?)```", content)
|
|
if match:
|
|
sql = match.group(1).strip()
|
|
if sql and sql.upper().strip().startswith("SELECT"):
|
|
return sql
|
|
# Fallback: find line starting with SELECT
|
|
for line in content.split("\n"):
|
|
line = line.strip()
|
|
if line.upper().startswith("SELECT"):
|
|
return line
|
|
return None
|
|
|
|
def _sanitize_sql_typos(sql: str) -> str:
|
|
"""Fix common LLM SQL typos that cause syntax errors."""
|
|
if not sql:
|
|
return sql
|
|
# Fix "CASE WHENLAGerplatz" - missing space after WHEN when followed directly by identifier
|
|
sql = re.sub(r"WHEN([A-Za-z_][A-Za-z0-9_.\"]*)", r"WHEN \1", sql, flags=re.IGNORECASE)
|
|
# Fix "LAGerplatz_Artikel" / "LAGerplatz" -> correct casing
|
|
sql = re.sub(r"\bLAGerplatz_Artikel\b", "Lagerplatz_Artikel", sql)
|
|
sql = re.sub(r"\bLAGerplatz\b", "Lagerplatz", sql)
|
|
# Preprocessor uses Einkaufspreis (not Einkaufspreis_neu) and m_Artikel (not ARTIKEL)
|
|
sql = sql.replace('"Einkaufspreis_neu"', '"Einkaufspreis"')
|
|
sql = sql.replace("Einkaufspreis_neu.", "Einkaufspreis.")
|
|
sql = re.sub(
|
|
r'"Einkaufspreis"\."ARTIKEL"',
|
|
'"Einkaufspreis"."m_Artikel"',
|
|
sql,
|
|
)
|
|
return sql
|
|
|
|
async def parse_execute_sql_node(state: ChatState) -> dict:
    """Parse SQL from last AIMessage, execute via preprocessor, add ToolMessage."""
    last_msg = state.messages[-1] if state.messages else None
    if not isinstance(last_msg, AIMessage):
        return {"messages": [ToolMessage(content="Fehler: Keine AI-Antwort zum Parsen.", tool_call_id="parse_0", name="sqlite_query")]}
    sql = _parse_sql_from_content(last_msg.content or "")
    if not sql or not sql_tool:
        return {"messages": [ToolMessage(content="Konnte keine SQL-Abfrage aus der Antwort extrahieren.", tool_call_id="parse_0", name="sqlite_query")]}
    # Repair well-known LLM typos before hitting the database.
    sql = _sanitize_sql_typos(sql)
    try:
        result = await sql_tool.ainvoke({"query": sql})
    except Exception as e:
        # Surface the failure as tool output so agent_formulate can report it.
        logger.error(f"SQL execution failed: {e}")
        result = f"Fehler bei der Ausführung: {e}"
    return {"messages": [ToolMessage(content=str(result), tool_call_id="parse_0", name="sqlite_query")]}

# Task suffix appended to the formulation system prompt: forbids invented
# data and restricts the answer strictly to the delivered database rows.
FORMULATE_TASK = (
    "\n\n--- AKTUELLE AUFGABE ---\n"
    "Du erhältst eine Benutzerfrage und die exakten Datenbankergebnisse. "
    "KRITISCH: Nutze NUR die gelieferten Daten. Erfinde NIEMALS Daten (keine LED-A01, LED Rot, etc.). "
    "Wenn die Ergebnisse NUR eine Zahl enthalten (z.B. '1. COUNT(*): 806'): Reportiere NUR diese Zahl, KEINE erfundene Tabelle. "
    "Eine Tabelle darf NUR erstellt werden, wenn echte Zeilen '1. Spalte: Wert, ...' in den Daten stehen. "
    "Beachte die obige ANTWORTSTRUKTUR."
)
|
|
|
|
async def agent_formulate_node(state: ChatState) -> dict:
    """Formulate final answer. Uses intro + response_structure sections (not schema)."""
    # Pick the LAST human question and the LAST sqlite_query result in history.
    human_content = ""
    tool_content = ""
    for m in state.messages:
        if isinstance(m, HumanMessage):
            human_content = m.content or ""
        if isinstance(m, ToolMessage) and getattr(m, "name", "") == "sqlite_query":
            tool_content = m.content or ""
    if not tool_content or not tool_content.strip():
        logger.warning("agent_formulate: no tool_content (sqlite_query) in state.messages")
        return {"messages": [AIMessage(content="Die Datenbankabfrage konnte keine Ergebnisse liefern. Bitte versuchen Sie es mit einer anderen Formulierung.")]}
    # When SQL failed, return error directly - don't let model hallucinate success
    if "Query failed" in tool_content or tool_content.strip().startswith("Error"):
        err_summary = "Die Datenbankabfrage ist fehlgeschlagen."
        if "no such column" in tool_content:
            err_summary += " Ein Spaltenname scheint nicht zu passen. Bitte die Anfrage anders formulieren."
        return {"messages": [AIMessage(content=err_summary)]}
    # Convert to markdown table so model copies exact values instead of reformatting/hallucinating
    formatted_data = _tool_output_to_markdown_table(tool_content)
    logger.debug(f"agent_formulate: tool_content length={len(tool_content)}, formatted={len(formatted_data)}")
    ctx_len = _get_context_length()
    # Half the context (converted to chars) for the system prompt, minus the
    # task suffix appended below (min 3000 chars).
    max_system_chars = max(3000, int(ctx_len * 0.5) * bytes_per_token) - len(FORMULATE_TASK)
    # Use intro + response_structure (mandatory format)
    resp_struct = _prompt_sections["response_structure"] or _prompt_sections["intro"]
    intro_formulate = _prompt_sections["intro"]
    combined = f"{intro_formulate}\n\n{resp_struct}" if intro_formulate != resp_struct else resp_struct
    # Fit within context; prefer keeping response_structure intact
    if len(combined) + len(FORMULATE_TASK) > max_system_chars:
        combined = _truncate_system_prompt(combined, max_system_chars - len(FORMULATE_TASK), "")
    system_content = combined + FORMULATE_TASK
    prompt = (
        f"Benutzerfrage: {human_content}\n\n"
        "--- VORGEGEBENE DATEN (diese Tabelle/Zahlen UNVERÄNDERT in die Antwort übernehmen): ---\n"
        f"{formatted_data}\n\n"
        "Die obige Tabelle bzw. Zahlen sind die EINZIGEN erlaubten Daten. Kopiere sie 1:1. "
        "Berechne keine eigenen Summen/Anzahlen - nutze die gelieferten Werte. Formuliere die Antwort:"
    )
    window = [SystemMessage(content=system_content), HumanMessage(content=prompt)]
    try:
        response = await llm_plain.ainvoke(window)
    except ValueError as exc:
        if "No suitable model found" in str(exc):
            response = AIMessage(content="Es gab einen Fehler bei der Formulierung. Bitte versuchen Sie es erneut.")
        else:
            raise
    # Sanitize: strip leaked chat template tokens (<|im_start|> etc.) and trailing junk
    if response.content:
        response = AIMessage(content=_sanitize_llm_response(response.content))
    return {"messages": [response]}
|
|
|
|
async def agent_tavily_node(state: ChatState) -> dict:
    """Agent with Tavily only. Uses intro + response_structure (no schema)."""
    resp_struct = _prompt_sections["response_structure"] or ""
    intro_tavily = _prompt_sections["intro"]
    combined = f"{intro_tavily}\n\n{resp_struct}" if resp_struct else intro_tavily
    # Fixed 6000-char cap: this path does not need the schema budget.
    system_content = _truncate_system_prompt(combined, 6000, "")
    return await _agent_common(state, system_content, llm_tavily, "agent_tavily")

async def agent_answer_node(state: ChatState) -> dict:
    """Agent with no tools. Uses intro + response_structure."""
    resp_struct = _prompt_sections["response_structure"] or ""
    intro_answer = _prompt_sections["intro"]
    combined = f"{intro_answer}\n\n{resp_struct}" if resp_struct else intro_answer
    system_content = _truncate_system_prompt(combined, 6000, "")
    return await _agent_common(state, system_content, llm_plain, "agent_answer")

def should_continue_tavily(state: ChatState) -> str:
    """Loop into the tools node while the last reply still requests tool calls."""
    last = state.messages[-1]
    return "tools" if getattr(last, "tool_calls", None) else END

def route_back(state: ChatState) -> str:
    """Route from tools back to agent_tavily (SQL path uses parse_execute_sql, no tools loop)."""
    # Tools node is only reached from agent_tavily when it returns tool_calls
    return "agent_tavily" if tavily_tool else "agent_answer"
|
|
|
|
async def tools_with_retry(state: ChatState) -> dict:
    """Tools node with parallel execution and retry logic.

    Executes every tool call of the last AI message concurrently and, when a
    result looks empty, injects a retry instruction (at most twice).

    Args:
        state: The current chat state.

    Returns:
        The updated chat state after tool execution.
    """
    import asyncio

    # Get tool calls from the last message
    last_message = state.messages[-1]
    tool_calls = getattr(last_message, "tool_calls", [])

    if not tool_calls:
        return {"messages": []}

    # Create a lookup for tools by name
    tools_by_name = {t.name: t for t in tools}

    async def execute_single_tool(tool_call):
        """Execute a single tool call and wrap the result in a ToolMessage."""
        # Support both flat and OpenAI-style ({"function": {...}}) layouts.
        tool_name = tool_call.get("name") or tool_call.get("function", {}).get("name")
        tool_id = tool_call.get("id", f"call_{tool_name}")
        args = tool_call.get("args") or tool_call.get("function", {}).get("arguments", {})

        if isinstance(args, str):
            import json
            try:
                args = json.loads(args)
            except (ValueError, TypeError):
                # Fix: was a bare `except:` that also swallowed
                # KeyboardInterrupt/SystemExit. Non-JSON argument strings are
                # passed through as a single "input" argument.
                args = {"input": args}

        tool = tools_by_name.get(tool_name)
        if not tool:
            return ToolMessage(
                content=f"Error: Tool '{tool_name}' not found",
                tool_call_id=tool_id,
                name=tool_name
            )

        try:
            # Execute tool asynchronously, preferring the native coroutine
            if asyncio.iscoroutinefunction(tool.coroutine):
                result = await tool.coroutine(**args)
            elif hasattr(tool, 'ainvoke'):
                result = await tool.ainvoke(args)
            else:
                result = tool.invoke(args)

            return ToolMessage(
                content=str(result),
                tool_call_id=tool_id,
                name=tool_name
            )
        except Exception as e:
            logger.error(f"Tool {tool_name} failed: {e}")
            return ToolMessage(
                content=f"Error executing {tool_name}: {str(e)}",
                tool_call_id=tool_id,
                name=tool_name
            )

    # Execute ALL tool calls in parallel
    logger.info(f"Executing {len(tool_calls)} tool calls in parallel")
    tool_messages = await asyncio.gather(
        *[execute_single_tool(tc) for tc in tool_calls],
        return_exceptions=True
    )

    # Convert exceptions to error messages
    result_messages = []
    for i, msg in enumerate(tool_messages):
        if isinstance(msg, Exception):
            tool_call = tool_calls[i]
            tool_name = tool_call.get("name", "unknown")
            tool_id = tool_call.get("id", f"call_{i}")
            result_messages.append(ToolMessage(
                content=f"Error: {str(msg)}",
                tool_call_id=tool_id,
                name=tool_name
            ))
        else:
            result_messages.append(msg)

    result = {"messages": result_messages}

    # Check if we got no results and should retry
    no_results_keywords = [
        "returned 0 rows",
        "no data",
        "keine artikel gefunden",
        "keine ergebnisse"
    ]

    # Marker used to count prior retry instructions in the history. Must be a
    # substring of the retry message injected below.
    retry_marker = "alternative Suchstrategie"

    # Check tool results for no data
    for msg in result.get("messages", []):
        content = getattr(msg, "content", "")
        if isinstance(content, str):
            content_lower = content.lower()
            if any(keyword in content_lower for keyword in no_results_keywords):
                # Fix: the old counter searched for the literal substring
                # "retry", which never occurs in the (German) retry message,
                # so the max-2-retries cap was never enforced. Count the
                # actual marker text instead to avoid unbounded retry loops.
                retry_count = sum(
                    1 for m in state.messages
                    if retry_marker in str(getattr(m, "content", ""))
                )
                if retry_count < 2:  # Allow max 2 retries
                    logger.info("No results found in tool output, adding retry instruction")
                    retry_message = HumanMessage(
                        content="WICHTIG: Die vorherige Suche hat keine Ergebnisse gefunden. "
                        "Bitte versuche eine alternative Suchstrategie:\n"
                        "1. Wenn die Frage im Format 'X von Y' war (z.B. 'Lampen von Eaton'), "
                        "verwende IMMER eine Kombination aus Lieferanten-Filter (WHERE a.\"Lieferant\" LIKE '%Y%') "
                        "UND Produkttyp-Filter (WHERE a.\"Artikelbezeichnung\" LIKE '%X%' OR ...)\n"
                        "2. Verwende mehrere Synonyme für den Produkttyp (z.B. bei 'Lampen': Lampe, LED, Beleuchtung, Licht, Leuchte, Strahler)\n"
                        "3. Führe zuerst eine COUNT-Abfrage durch, dann die Detail-Abfrage mit Lagerbeständen\n"
                        "4. Verwende LIKE '%Lieferant%' für den Lieferanten-Filter, um auch Varianten zu finden"
                    )
                    result["messages"].append(retry_message)
                    break

    return result
|
|
|
|
# Compose the workflow: planner -> route -> agent_* -> tools (Tavily only) or END
workflow = StateGraph(ChatState)
workflow.add_node("planner", planner_node)
workflow.add_node("agent_sql_plan", agent_sql_plan_node)
workflow.add_node("parse_execute_sql", parse_execute_sql_node)
workflow.add_node("agent_formulate", agent_formulate_node)
workflow.add_node("agent_tavily", agent_tavily_node)
workflow.add_node("agent_answer", agent_answer_node)
workflow.add_node("tools", tools_with_retry)
# Every turn starts at the planner, which sets state.plan for routing.
workflow.add_edge(START, "planner")
workflow.add_conditional_edges("planner", route_by_plan)
# SQL path: agent_sql_plan -> parse_execute_sql -> agent_formulate -> END (no tools, /api/analyze compatible)
workflow.add_edge("agent_sql_plan", "parse_execute_sql")
workflow.add_edge("parse_execute_sql", "agent_formulate")
workflow.add_edge("agent_formulate", END)
# Tavily path loops agent_tavily <-> tools until no further tool calls.
workflow.add_conditional_edges("agent_tavily", should_continue_tavily)
workflow.add_edge("agent_answer", END)
workflow.add_conditional_edges("tools", route_back)
# Persist per-thread state via the database-backed checkpointer.
return workflow.compile(checkpointer=memory)
|
|
|
|
async def chat(self, message: str, chat_id: str = "default") -> List[BaseMessage]:
    """Run a single non-streaming chat turn and return the full history.

    Args:
        message: The user message to process.
        chat_id: The chat thread ID.

    Returns:
        The list of messages in the chat history.
    """
    # Scope memory to the requested conversation thread.
    runnable_config = {"configurable": {"thread_id": chat_id}}
    final_state = await self.app.ainvoke(
        {"messages": [HumanMessage(content=message)]}, config=runnable_config
    )
    return final_state["messages"]
|
|
|
|
async def stream_events(
    self, *, message: str, chat_id: str = "default"
) -> AsyncIterator[dict]:
    """Stream UI-focused events using astream_events v2.

    Args:
        message: The user message to process.
        chat_id: Logical thread identifier; forwarded in the runnable config so
            memory and tools are scoped per thread.

    Yields:
        dict: One of:
            - ``{"type": "status", "label": str}`` for short progress updates.
            - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
              where ``chat_history`` only includes ``user``/``assistant`` roles.
            - ``{"type": "error", "message": str}`` if an exception occurs.
    """
    # Thread-aware config for LangGraph/LangChain
    config = {"configurable": {"thread_id": chat_id}}

    def _is_root(ev: dict) -> bool:
        """Return True if the event is from the root run (v2: empty parent_ids)."""
        return not ev.get("parent_ids")

    try:
        async for event in self.app.astream_events(
            {"messages": [HumanMessage(content=message)]},
            config=config,
            version="v2",
        ):
            etype = event.get("event")
            ename = event.get("name") or ""
            edata = event.get("data") or {}

            # Stream human-readable progress via the special send_streaming_message tool
            # Match the legacy implementation exactly (line 267-272 in legacy/chatbot.py)
            if etype == "on_tool_start":
                # Log all tool starts to debug
                logger.debug(f"Tool start event: name='{ename}', event='{etype}'")
                if ename == "send_streaming_message":
                    tool_in = edata.get("input") or {}
                    msg = tool_in.get("message")
                    logger.info(f"send_streaming_message tool called with input: {tool_in}")
                    if isinstance(msg, str) and msg.strip():
                        logger.info(f"Status-Update gesendet: {msg.strip()}")
                        yield {"type": "status", "label": msg.strip()}
                continue

            # Emit the final payload when the root run finishes
            if etype == "on_chain_end" and _is_root(event):
                output_obj = edata.get("output")

                # Extract message list from the graph's final output
                final_msgs = ChatStreamingHelper.extract_messages_from_output(
                    output_obj=output_obj
                )

                # Normalize for the frontend (only user/assistant with text content)
                # Exclude planner-only and SQL-path intermediate messages from chat display
                _planner_only = frozenset(("sql", "tavily", "both", "none"))
                chat_history_payload: List[dict] = []
                for m in final_msgs:
                    if isinstance(m, BaseMessage):
                        d = ChatStreamingHelper.message_to_dict(msg=m)
                    elif isinstance(m, dict):
                        d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                    else:
                        continue
                    if d.get("role") not in ("user", "assistant") or not d.get("content"):
                        continue
                    content = (d.get("content") or "").strip()
                    if d.get("role") == "assistant" and content.lower() in _planner_only:
                        continue  # Skip planner routing message
                    # Skip agent_sql_plan output: ```sql block OR raw SQL (SELECT...FROM/JOIN)
                    if d.get("role") == "assistant":
                        cu = content.upper()
                        if content.startswith("```") or (
                            cu.startswith("SELECT") and ("FROM" in cu or "JOIN" in cu)
                        ):
                            continue
                    # Strip leaked chat template tokens (<|im_start|> etc.);
                    # note this runs on ALL kept messages, not only assistant.
                    content = _sanitize_llm_response(content)
                    if not content:
                        continue
                    d = {**d, "content": content}
                    chat_history_payload.append(d)

                yield {
                    "type": "final",
                    "response": {
                        "thread": chat_id,
                        "chat_history": chat_history_payload,
                    },
                }
                return

    except Exception as exc:
        # Emit a single error envelope and end the stream
        logger.error(f"Exception in stream_events: {exc}", exc_info=True)
        yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}