"""Chatbot domain logic with LangGraph integration."""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Annotated, AsyncIterator, Any
|
|
import logging
|
|
|
|
from pydantic import BaseModel
|
|
from langchain_core.messages import (
|
|
BaseMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
trim_messages,
|
|
)
|
|
from langgraph.graph.message import add_messages
|
|
from langgraph.graph import StateGraph, START, END
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
from langgraph.prebuilt import ToolNode
|
|
from langchain_anthropic import ChatAnthropic
|
|
|
|
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
|
|
from modules.features.chatBot.utils.toolRegistry import get_registry
|
|
from modules.shared.configuration import APP_CONFIG
|
|
|
|
logger = logging.getLogger(__name__)


class ChatState(BaseModel):
    """Represents the state of a chat session."""

    messages: Annotated[list[BaseMessage], add_messages]
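
    # Note: `add_messages` is LangGraph's reducer for this field. Node updates
    # that return {"messages": [...]} are appended/merged into the existing
    # history (matched by message id) rather than overwriting it.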


def get_langchain_model(*, model_name: str) -> ChatAnthropic:
    """Map permission model names to LangChain ChatAnthropic models.

    Args:
        model_name: The model name from permissions (e.g., "claude_4_5").

    Returns:
        Configured ChatAnthropic instance. Unknown model names fall back to
        the default model with a logged warning rather than raising.
    """
    # Model name mapping (permission name -> Anthropic model ID)
    model_mapping = {
        "claude_4_5": "claude-sonnet-4-5",
        # Add more mappings as needed
    }

    anthropic_model = model_mapping.get(model_name)
    if not anthropic_model:
        logger.warning(
            f"Unknown model name '{model_name}', defaulting to claude-sonnet-4-5"
        )
        anthropic_model = "claude-sonnet-4-5"

    return ChatAnthropic(
        model=anthropic_model,
        api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
        temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
        max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
    )
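
# Example (a sketch): building a model from a permission name. Assumes
# APP_CONFIG holds a valid Anthropic API key; `invoke` accepts a plain
# string prompt.
#
#     model = get_langchain_model(model_name="claude_4_5")
#     reply = model.invoke("Hello!")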


@dataclass
class Chatbot:
    """Represents a chatbot with LangGraph integration."""

    model: Any
    memory: Any
    app: Any = None
    system_prompt: str = "You are a helpful assistant."
    context_window_size: int = 100000

    @classmethod
    async def create(
        cls,
        *,
        model: Any,
        memory: Any,
        system_prompt: str,
        tools: list,
        context_window_size: int = 100000,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use.
            memory: The chat memory checkpointer to use.
            system_prompt: The system prompt to initialize the chatbot.
            tools: List of LangChain tools the chatbot can use.
            context_window_size: Maximum tokens for context window.

        Returns:
            A configured Chatbot instance.
        """
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            context_window_size=context_window_size,
        )
        instance.app = instance._build_app(memory=memory, tools=tools)
        return instance
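
    # Example (a sketch; the checkpointer choice is illustrative):
    #
    #     from langgraph.checkpoint.memory import MemorySaver
    #
    #     model = get_langchain_model(model_name="claude_4_5")
    #     chatbot = await Chatbot.create(
    #         model=model,
    #         memory=MemorySaver(),
    #         system_prompt="You are a helpful assistant.",
    #         tools=[],
    #     )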

    def _build_app(
        self, *, memory: Any, tools: list
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        Args:
            memory: The chat memory checkpointer to use.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
            """Selects a window of messages that fit within the context window size.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fit within the context window size.
            """

            def approx_counter(items: list[BaseMessage]) -> int:
                """Approximate token counter using character length as a cheap proxy.

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                # Content may be a plain string or a list of content blocks
                # (e.g. tool-use messages); coerce to str before counting.
                return sum(len(str(getattr(m, "content", "") or "")) for m in items)

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=self.context_window_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                A partial state update containing the model's response message.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)

            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window

            # Call the LLM with tools
            response = llm_with_tools.invoke(window)

            # Return the new state
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message contains
            tool calls) or to end the workflow (if no tool calls are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            # Get the last message
            last_message = state.messages[-1]

            # Check if the last message contains tool calls
            # If so, continue to the tools node; otherwise, end the workflow
            return "tools" if getattr(last_message, "tool_calls", None) else END
        # Compose the workflow
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=memory)

    async def chat(
        self, *, message: str, chat_id: str = "default"
    ) -> list[BaseMessage]:
        """Processes a chat message and returns the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}

        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )

        # Extract and return the messages from the result
        return result["messages"]
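
    # Example (a sketch; assumes a `chatbot` built via Chatbot.create):
    #
    #     history = await chatbot.chat(message="Hi!", chat_id="thread-1")
    #     print(history[-1].content)  # latest assistant reply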

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config so
                memory and tools are scoped per thread.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}

                # Stream human-readable progress via the special send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input") or {}
                    msg = tool_in.get("message")
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue

                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")

                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=output_obj
                    )

                    # Normalize for the frontend (only user/assistant with text content)
                    chat_history_payload: list[dict] = []
                    for m in final_msgs:
                        if isinstance(m, BaseMessage):
                            d = ChatStreamingHelper.message_to_dict(msg=m)
                        elif isinstance(m, dict):
                            d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                        else:
                            continue
                        if d.get("role") in ("user", "assistant") and d.get("content"):
                            chat_history_payload.append(d)

                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    return
        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.error("Error in stream_events: %s", exc, exc_info=True)
            yield {"type": "error", "message": f"Error processing request: {exc}"}
|