# wiki/z-archive/implementation/Chatbot/legacy/chatbot.py
#
# 307 lines
# 11 KiB
# Python

"""Chatbot domain logic."""
import logging
from dataclasses import dataclass
from typing import Annotated, AsyncIterator, Any
from pydantic import BaseModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langgraph.graph.message import add_messages
# ^ add_messages aggregator keeps history in state
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langchain_tavily import TavilyExtract, TavilySearch
from langchain_core.tools import tool
from src.chat.domain.sqlitetool import SQLiteTool
from src.chat.domain.streaming_helper import ChatStreamingHelper
from src.settings import settings
logger = logging.getLogger(__name__)
# NOTE: the @tool decorator exposes the function docstring as the tool
# description, so the docstring wording below is deliberately left unchanged.
@tool
def send_streaming_message(message: str) -> str:
    """Send a streaming message to the user to provide updates during processing.
    Use this tool to send short status updates to the user while you are working
    on their request. This helps keep the user informed about what you are doing.
    Args:
    message: A short German message describing what you are currently doing.
    Examples: "Durchsuche Datenbank nach Lampen, LED, Leuchten, und Ähnlichem."
    "Suche im Internet nach Produktinformationen."
    "Analysiere Suchergebnisse."
    Returns:
    A confirmation that the message was sent.
    """
    # Intentionally a no-op: the tool call itself is the signal. The model's
    # status text travels to the frontend through the tool-call mechanism; the
    # value returned here is only the acknowledgement handed back to the model.
    confirmation = f"Status-Update gesendet: {message}"
    return confirmation
class ChatState(BaseModel):
    """State carried through the chat graph: the message history of a session."""

    # ``add_messages`` is the LangGraph aggregator attached to this field; it is
    # what keeps the history in state as nodes return new messages.
    messages: Annotated[list[BaseMessage], add_messages]
@dataclass
class Chatbot:
    """Tool-calling chatbot built on a LangGraph agent/tools loop.

    Build instances with :meth:`create`; it configures the tools and compiles
    the graph stored in ``app``. An instance constructed directly has
    ``app is None`` and cannot serve :meth:`chat` or :meth:`stream_events`.
    """

    model: Any  # chat model; must provide bind_tools()
    memory: Any  # checkpointer passed to workflow.compile()
    app: Any = None  # compiled graph; None until create() assigns it
    system_prompt: str = "You are a helpful assistant."

    @classmethod
    async def create(
        cls,
        model: Any,
        memory: Any,
        system_prompt: str,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use.
            memory: The chat memory (checkpointer) to use.
            system_prompt: The system prompt to initialize the chatbot.

        Returns:
            A configured instance of ``cls`` whose ``app`` is the compiled graph.
        """
        # Instantiate via ``cls`` (not the hard-coded class name) so that
        # subclasses get an instance of their own type and their overrides run.
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
        )
        configured_tools = await instance._configure_tools()
        instance.app = instance._build_app(memory, configured_tools)
        return instance

    async def _configure_tools(self) -> list[Any]:
        """Instantiate the tools the agent may call.

        All tool parameters are read from the module-level ``settings`` object.

        Returns:
            The tool list in binding order: the SQLite query tool, Tavily search,
            Tavily extract and the ``send_streaming_message`` status tool.
        """
        tavily_search_tool = TavilySearch(
            tavily_api_key=settings.tavily_api_key,
            max_results=settings.tavily_max_results,
            include_answer=settings.tavily_answer,
            include_images=settings.tavily_include_images,
            include_image_descriptions=settings.tavily_include_image_descriptions,
            search_depth=settings.tavily_search_depth,
            country=settings.tavily_country,
        )
        tavily_extract_tool = TavilyExtract(tavily_api_key=settings.tavily_api_key)
        sqlite_tool = await SQLiteTool.create(
            api_key=settings.pp_query_api_key,
            base_url=settings.pp_query_base_url,
        )
        return [
            sqlite_tool.get_tool(),
            tavily_search_tool,
            tavily_extract_tool,
            send_streaming_message,
        ]

    @staticmethod
    def _content_length(content: Any) -> int:
        """Return the number of text characters in a message ``content`` value.

        Args:
            content: A plain string, a list/tuple of content blocks (strings or
                dicts carrying a ``"text"`` entry), or anything else.

        Returns:
            The character count of all text found; non-text blocks and
            unsupported values count as 0.
        """
        if isinstance(content, str):
            return len(content)
        if isinstance(content, (list, tuple)):
            total = 0
            for part in content:
                if isinstance(part, str):
                    total += len(part)
                elif isinstance(part, dict):
                    text = part.get("text")
                    if isinstance(text, str):
                        total += len(text)
            return total
        return 0

    @staticmethod
    def _to_chat_history(messages: Any) -> list[dict]:
        """Normalize graph messages into frontend dicts (user/assistant with content)."""
        history: list[dict] = []
        for m in messages:
            if isinstance(m, BaseMessage):
                d = ChatStreamingHelper.message_to_dict(msg=m)
            elif isinstance(m, dict):
                d = ChatStreamingHelper.dict_message_to_dict(obj=m)
            else:
                continue
            if d.get("role") in ("user", "assistant") and d.get("content"):
                history.append(d)
        return history

    def _build_app(
        self, memory: Any, tools: list
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        The graph is ``START -> agent``, then ``agent -> tools -> agent`` for as
        long as the model's last message carries tool calls, otherwise
        ``agent -> END``.

        Args:
            memory: The chat memory (checkpointer) to compile the graph with.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
            """Select the most recent messages that fit the configured budget.

            Args:
                msgs: The list of messages to select from.

            Returns:
                The trimmed list of messages.
            """

            def approx_counter(items: list[BaseMessage]) -> int:
                """Approximate the size of ``items`` by counting text characters.

                Characters (not real tokens) are compared against
                ``settings.context_window_token_size``. List-style content is
                measured by its text, not by its number of blocks, so long
                multi-block messages are not under-counted.
                """
                return sum(
                    self._content_length(getattr(m, "content", None)) for m in items
                )

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=settings.context_window_token_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node: call the tool-bound model on the trimmed history.

            Args:
                state: The current chat state.

            Returns:
                A state update containing only the model's response message.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)
            # The system prompt is injected per call and never returned as a
            # state update, so it has to be prepended on every invocation.
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window
            # NOTE(review): synchronous invoke although the graph is driven via
            # ainvoke/astream_events below - consider an async node.
            response = llm_with_tools.invoke(window)
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Route after the agent node.

            Args:
                state: The current chat state.

            Returns:
                ``"tools"`` if the last message contains tool calls, else ``END``.
            """
            last_message = state.messages[-1]
            return "tools" if getattr(last_message, "tool_calls", None) else END

        # Compose the workflow
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=memory)

    async def chat(self, message: str, chat_id: str = "default") -> list[BaseMessage]:
        """Processes a chat message by calling the LLM and tools and returns the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}
        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )
        return result["messages"]

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config
                as ``thread_id``.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream human-readable progress via the special send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input")
                    # A malformed (non-dict) tool input must not abort the whole
                    # stream; a status update is purely cosmetic, so skip it.
                    msg = tool_in.get("message") if isinstance(tool_in, dict) else None
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue

                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=edata.get("output")
                    )
                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": self._to_chat_history(final_msgs),
                        },
                    }
                    return
        except Exception as exc:
            # Top-level streaming boundary: log with traceback, emit a single
            # error envelope and end the stream.
            logger.exception("Exception in stream_events: %s", exc)
            yield {"type": "error", "message": f"Fehler beim Verarbeiten: {exc}"}