307 lines
11 KiB
Python
307 lines
11 KiB
Python
"""Chatbot domain logic."""
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Annotated, AsyncIterator, Any
|
|
from pydantic import BaseModel
|
|
|
|
from langchain_core.messages import (
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
trim_messages,
|
|
)
|
|
from langgraph.graph.message import add_messages
|
|
|
|
# ^ add_messages aggregator keeps history in state
|
|
from langgraph.graph import StateGraph, START, END
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
from langgraph.prebuilt import ToolNode
|
|
|
|
from langchain_tavily import TavilyExtract, TavilySearch
|
|
from langchain_core.tools import tool
|
|
|
|
from src.chat.domain.sqlitetool import SQLiteTool
|
|
from src.chat.domain.streaming_helper import ChatStreamingHelper
|
|
from src.settings import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@tool
def send_streaming_message(message: str) -> str:
    """Send a streaming message to the user to provide updates during processing.

    Use this tool to send short status updates to the user while you are working
    on their request. This helps keep the user informed about what you are doing.

    Args:
        message: A short German message describing what you are currently doing.
            Examples: "Durchsuche Datenbank nach Lampen, LED, Leuchten, und Ähnlichem."
                     "Suche im Internet nach Produktinformationen."
                     "Analysiere Suchergebnisse."

    Returns:
        A confirmation that the message was sent.
    """
    # NOTE(review): the docstring above is presumably what the @tool decorator
    # exposes to the model as the tool description -- confirm before rewording it.
    #
    # The body performs no work of its own. The status text reaches the frontend
    # through the tool-call mechanism: Chatbot.stream_events in this module
    # watches for "on_tool_start" events named "send_streaming_message" and
    # forwards the ``message`` argument as a status update. The string returned
    # here is only the tool result echoing the message back.
    return f"Status-Update gesendet: {message}"
|
|
|
|
|
|
class ChatState(BaseModel):
    """State object passed between the nodes of the chat graph.

    Attributes:
        messages: The messages of the conversation held in this state.
    """

    # ``add_messages`` is attached as ``Annotated`` metadata; per the import
    # comment in this module it is the aggregator that keeps the message
    # history in state.
    messages: Annotated[list[BaseMessage], add_messages]
|
|
|
|
|
|
@dataclass
class Chatbot:
    """A tool-using chatbot backed by a LangGraph state graph.

    Instances should normally be obtained through :meth:`create`, which
    configures the tools and compiles the graph stored in ``app``. An instance
    built directly through the dataclass constructor has ``app`` set to
    ``None`` until a graph is assigned.
    """

    # Chat model; must provide ``bind_tools`` (used in ``_build_app``).
    model: Any
    # Checkpointer handed to ``workflow.compile`` so history is kept per thread.
    memory: Any
    # Compiled graph; populated by ``create``.
    app: Any = None
    # Prepended to the message window when it does not start with a system message.
    system_prompt: str = "You are a helpful assistant."

    @classmethod
    async def create(
        cls,
        model: Any,
        memory: Any,
        system_prompt: str,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use.
            memory: The chat memory (checkpointer) to use.
            system_prompt: The system prompt to initialize the chatbot.

        Returns:
            A configured instance of ``cls`` whose ``app`` holds the compiled
            graph.
        """
        # Instantiate ``cls`` rather than ``Chatbot`` so that subclasses calling
        # ``create`` receive an instance of their own type.
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
        )
        configured_tools = await instance._configure_tools()
        instance.app = instance._build_app(memory, configured_tools)
        return instance

    async def _configure_tools(self) -> list:
        """Instantiate the tools the agent may call.

        All configuration values are read from the module-level ``settings``.

        Returns:
            A list with the SQLite query tool, the Tavily search tool, the
            Tavily extract tool and ``send_streaming_message``, in that order.
        """
        tavily_search_tool = TavilySearch(
            tavily_api_key=settings.tavily_api_key,
            max_results=settings.tavily_max_results,
            include_answer=settings.tavily_answer,
            include_images=settings.tavily_include_images,
            include_image_descriptions=settings.tavily_include_image_descriptions,
            search_depth=settings.tavily_search_depth,
            country=settings.tavily_country,
        )
        tavily_extract_tool = TavilyExtract(tavily_api_key=settings.tavily_api_key)
        sqlite_tool = await SQLiteTool.create(
            api_key=settings.pp_query_api_key,
            base_url=settings.pp_query_base_url,
        )
        return [
            sqlite_tool.get_tool(),
            tavily_search_tool,
            tavily_extract_tool,
            send_streaming_message,
        ]

    @staticmethod
    def _approx_token_count(items: list[BaseMessage]) -> int:
        """Approximate the size of a list of messages by counting characters.

        The character count is used as a cheap stand-in for a real token
        count. Message content may be a plain string or a list of content
        parts; for a list, the text of every part is counted instead of the
        number of parts.

        Args:
            items: Messages to measure. Objects without a ``content``
                attribute, or with empty content, contribute 0.

        Returns:
            The approximate size of the messages.
        """
        total = 0
        for item in items:
            content = getattr(item, "content", "") or ""
            if isinstance(content, str):
                total += len(content)
                continue
            for part in content:
                if isinstance(part, str):
                    total += len(part)
                    continue
                text = part.get("text") if isinstance(part, dict) else None
                # Parts without text (e.g. non-text blocks) keep a nominal
                # weight of 1 so they are never counted as free.
                total += len(text) if isinstance(text, str) else 1
        return total

    def _build_app(
        self, memory: Any, tools: list
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        The graph has two nodes: ``agent`` (calls the model) and ``tools``
        (executes requested tool calls). After ``agent`` the graph either goes
        to ``tools`` or ends; after ``tools`` it always returns to ``agent``.

        Args:
            memory: The chat memory (checkpointer) to compile the graph with.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
            """Trim the history so it fits the configured context window.

            Args:
                msgs: The list of messages to select from.

            Returns:
                The most recent messages whose approximate size stays within
                ``settings.context_window_token_size``.
            """
            return trim_messages(
                msgs,
                strategy="last",
                token_counter=self._approx_token_count,
                max_tokens=settings.context_window_token_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                A state update containing the model response as the only new
                message.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)

            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window

            # Call the LLM with tools
            response = llm_with_tools.invoke(window)

            # Only the new message is returned as the update
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message contains
            tool calls) or to end the workflow (if no tool calls are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            last_message = state.messages[-1]
            # ``tool_calls`` may be missing or empty; both mean "no tools requested"
            return "tools" if getattr(last_message, "tool_calls", None) else END

        # Compose the workflow
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=memory)

    async def chat(self, message: str, chat_id: str = "default") -> list[BaseMessage]:
        """Processes a chat message by calling the LLM and tools and returns the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}

        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )

        # Extract and return the messages from the result
        return result["messages"]

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config
                as ``thread_id``.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles
                  with non-empty content.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream human-readable progress via the special send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input")
                    # A non-dict input must not abort the whole stream with an
                    # AttributeError; it simply produces no status update.
                    msg = tool_in.get("message") if isinstance(tool_in, dict) else None
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue

                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")

                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=output_obj
                    )

                    # Normalize for the frontend (only user/assistant with content)
                    chat_history_payload: list[dict] = []
                    for m in final_msgs:
                        if isinstance(m, BaseMessage):
                            d = ChatStreamingHelper.message_to_dict(msg=m)
                        elif isinstance(m, dict):
                            d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                        else:
                            continue
                        if d.get("role") in ("user", "assistant") and d.get("content"):
                            chat_history_payload.append(d)

                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    # The root run has finished; nothing after it is forwarded.
                    return

        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.exception("Exception in stream_events: %s", exc)
            yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}
|