feat: chatbot w/ streaming basics

This commit is contained in:
Christopher Gondek 2025-10-03 16:48:33 +02:00
parent 8707203ac2
commit 4bfeded9d0
12 changed files with 1315 additions and 159 deletions

17
app.py
View file

@ -240,9 +240,26 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: set up the LangGraph checkpointer and event
    manager on startup, tear them down on shutdown."""
    logger.info("Application is starting up")
    # Initialize LangGraph checkpointer
    from modules.features.chatBot.utils.checkpointer import (
        initialize_checkpointer,
        close_checkpointer,
    )
    try:
        await initialize_checkpointer()
        logger.info("LangGraph checkpointer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}")
        # Continue startup even if checkpointer fails to initialize
    eventManager.start()
    yield
    # Cleanup
    eventManager.stop()
    # NOTE(review): close_checkpointer() is awaited even when initialization
    # failed above — assumes it is a safe no-op in that case; confirm.
    await close_checkpointer()
    logger.info("Application has been shut down")

View file

@ -8,10 +8,18 @@ import uuid
class ChatStat(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
workflowId: Optional[str] = Field(
None, description="Foreign key to workflow (for workflow stats)"
)
messageId: Optional[str] = Field(
None, description="Foreign key to message (for message stats)"
)
processingTime: Optional[float] = Field(
None, description="Processing time in seconds"
)
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
@ -37,14 +45,23 @@ register_model_labels(
class ChatLog(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
workflowId: str = Field(description="Foreign key to workflow")
message: str = Field(description="Log message")
type: str = Field(description="Log type (info, warning, error, etc.)")
timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)")
timestamp: float = Field(
default_factory=get_utc_timestamp,
description="When the log entry was created (UTC timestamp in seconds)",
)
status: Optional[str] = Field(None, description="Status of the log entry")
progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)")
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
progress: Optional[float] = Field(
None, description="Progress indicator (0.0 to 1.0)"
)
performance: Optional[Dict[str, Any]] = Field(
None, description="Performance metrics"
)
register_model_labels(
@ -64,7 +81,9 @@ register_model_labels(
class ChatDocument(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
messageId: str = Field(description="Foreign key to message")
fileId: str = Field(description="Foreign key to file")
fileName: str = Field(description="Name of the file")
@ -73,7 +92,9 @@ class ChatDocument(BaseModel, ModelMixin):
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
taskNumber: Optional[int] = Field(None, description="Task number within round")
actionNumber: Optional[int] = Field(None, description="Action number within task")
actionId: Optional[str] = Field(None, description="ID of the action that created this document")
actionId: Optional[str] = Field(
None, description="ID of the action that created this document"
)
register_model_labels(
@ -96,13 +117,19 @@ register_model_labels(
class ContentMetadata(BaseModel, ModelMixin):
size: int = Field(description="Content size in bytes")
pages: Optional[int] = Field(None, description="Number of pages for multi-page content")
pages: Optional[int] = Field(
None, description="Number of pages for multi-page content"
)
error: Optional[str] = Field(None, description="Processing error if any")
width: Optional[int] = Field(None, description="Width in pixels for images/videos")
height: Optional[int] = Field(None, description="Height in pixels for images/videos")
height: Optional[int] = Field(
None, description="Height in pixels for images/videos"
)
colorMode: Optional[str] = Field(None, description="Color mode")
fps: Optional[float] = Field(None, description="Frames per second for videos")
durationSec: Optional[float] = Field(None, description="Duration in seconds for media")
durationSec: Optional[float] = Field(
None, description="Duration in seconds for media"
)
mimeType: str = Field(description="MIME type of the content")
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
@ -144,7 +171,9 @@ register_model_labels(
class ExtractedContent(BaseModel, ModelMixin):
id: str = Field(description="Reference to source ChatDocument")
contents: List[ContentItem] = Field(default_factory=list, description="List of content items")
contents: List[ContentItem] = Field(
default_factory=list, description="List of content items"
)
register_model_labels(
@ -156,27 +185,53 @@ register_model_labels(
},
)
class ChatMessage(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
)
workflowId: str = Field(description="Foreign key to workflow")
parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading")
documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents")
documentsLabel: Optional[str] = Field(None, description="Label for the set of documents")
parentMessageId: Optional[str] = Field(
None, description="Parent message ID for threading"
)
documents: List[ChatDocument] = Field(
default_factory=list, description="Associated documents"
)
documentsLabel: Optional[str] = Field(
None, description="Label for the set of documents"
)
message: Optional[str] = Field(None, description="Message content")
role: str = Field(description="Role of the message sender")
status: str = Field(description="Status of the message (first, step, last)")
sequenceNr: int = Field(description="Sequence number of the message (set automatically)")
publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)")
sequenceNr: int = Field(
description="Sequence number of the message (set automatically)"
)
publishedAt: float = Field(
default_factory=get_utc_timestamp,
description="When the message was published (UTC timestamp in seconds)",
)
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
success: Optional[bool] = Field(None, description="Whether the message processing was successful")
actionId: Optional[str] = Field(None, description="ID of the action that produced this message")
actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message")
actionName: Optional[str] = Field(None, description="Name of the action that produced this message")
success: Optional[bool] = Field(
None, description="Whether the message processing was successful"
)
actionId: Optional[str] = Field(
None, description="ID of the action that produced this message"
)
actionMethod: Optional[str] = Field(
None, description="Method of the action that produced this message"
)
actionName: Optional[str] = Field(
None, description="Name of the action that produced this message"
)
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
taskNumber: Optional[int] = Field(None, description="Task number within round")
actionNumber: Optional[int] = Field(None, description="Action number within task")
taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, success, fail, retry")
actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail")
taskProgress: Optional[str] = Field(
None, description="Task progress status: pending, running, success, fail, retry"
)
actionProgress: Optional[str] = Field(
None, description="Action progress status: pending, running, success, fail"
)
register_model_labels(
@ -208,31 +263,135 @@ register_model_labels(
class ChatWorkflow(BaseModel, ModelMixin):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Primary key",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
mandateId: str = Field(
description="ID of the mandate this workflow belongs to",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
status: str = Field(
description="Current status of the workflow",
frontend_type="select",
frontend_readonly=False,
frontend_required=False,
frontend_options=[
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
])
name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True)
currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False)
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False)
tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
{"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}},
],
)
name: Optional[str] = Field(
None,
description="Name of the workflow",
frontend_type="text",
frontend_readonly=False,
frontend_required=True,
)
currentRound: int = Field(
description="Current round number",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
currentTask: int = Field(
default=0,
description="Current task number",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
currentAction: int = Field(
default=0,
description="Current action number",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
totalTasks: int = Field(
default=0,
description="Total number of tasks in the workflow",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
totalActions: int = Field(
default=0,
description="Total number of actions in the workflow",
frontend_type="integer",
frontend_readonly=True,
frontend_required=False,
)
lastActivity: float = Field(
default_factory=get_utc_timestamp,
description="Timestamp of last activity (UTC timestamp in seconds)",
frontend_type="timestamp",
frontend_readonly=True,
frontend_required=False,
)
startedAt: float = Field(
default_factory=get_utc_timestamp,
description="When the workflow started (UTC timestamp in seconds)",
frontend_type="timestamp",
frontend_readonly=True,
frontend_required=False,
)
logs: List[ChatLog] = Field(
default_factory=list,
description="Workflow logs",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
messages: List[ChatMessage] = Field(
default_factory=list,
description="Messages in the workflow",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
stats: Optional[ChatStat] = Field(
None,
description="Workflow statistics",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
tasks: list = Field(
default_factory=list,
description="List of tasks in the workflow",
frontend_type="text",
frontend_readonly=True,
frontend_required=False,
)
workflowMode: str = Field(
default="Actionplan",
description="Workflow mode selector",
frontend_type="select",
frontend_readonly=False,
frontend_required=False,
frontend_options=[
{
"value": "Actionplan",
"label": {"en": "Action Plan", "fr": "Plan d'actions"},
},
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
])
maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False)
],
)
maxSteps: int = Field(
default=5,
description="Maximum number of iterations in react mode",
frontend_type="integer",
frontend_readonly=False,
frontend_required=False,
)
register_model_labels(
@ -278,7 +437,10 @@ register_model_labels(
"completed_tasks": {"en": "Completed Tasks", "fr": "Tâches terminées"},
"total_tasks": {"en": "Total Tasks", "fr": "Total des tâches"},
"execution_time": {"en": "Execution Time", "fr": "Temps d'exécution"},
"final_results_count": {"en": "Final Results Count", "fr": "Nombre de résultats finaux"},
"final_results_count": {
"en": "Final Results Count",
"fr": "Nombre de résultats finaux",
},
"error": {"en": "Error", "fr": "Erreur"},
"phase": {"en": "Phase", "fr": "Phase"},
},
@ -300,5 +462,3 @@ register_model_labels(
"userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"},
},
)

View file

@ -0,0 +1,130 @@
"""Chatbot API models for request/response handling."""
from typing import List, Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import register_model_labels, ModelMixin
# Chatbot API Models
class MessageItem(BaseModel, ModelMixin):
    """Individual message in a thread"""

    # Field descriptions are surfaced in the generated API schema.
    role: str = Field(..., description="Message role (user or assistant)")
    content: str = Field(..., description="Message content")
    timestamp: float = Field(..., description="Message timestamp (Unix timestamp)")
class ChatMessageRequest(BaseModel, ModelMixin):
    """Request model for posting a chat message"""

    # Omitting thread_id signals the backend to start a new thread.
    thread_id: Optional[str] = Field(
        None, description="Thread ID (creates new thread if not provided)"
    )
    message: str = Field(..., description="User message content")
class ChatMessageResponse(BaseModel, ModelMixin):
    """Response model for posting a chat message"""

    thread_id: str = Field(..., description="Thread ID")
    # Full thread history, not just the newly produced message.
    messages: List[MessageItem] = Field(..., description="All messages in thread")
class ThreadSummary(BaseModel, ModelMixin):
    """Summary of a chat thread for list view"""

    thread_id: str = Field(..., description="Thread ID")
    created_at: float = Field(..., description="Thread creation timestamp")
    # Preview text for the thread list UI.
    last_message: str = Field(..., description="Last message content")
    message_count: int = Field(..., description="Total number of messages")
class ThreadListResponse(BaseModel, ModelMixin):
    """Response model for listing all threads"""

    threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
class ThreadDetail(BaseModel, ModelMixin):
    """Detailed view of a single thread"""

    thread_id: str = Field(..., description="Thread ID")
    created_at: float = Field(..., description="Thread creation timestamp")
    messages: List[MessageItem] = Field(
        ..., description="All messages in chronological order"
    )
class DeleteResponse(BaseModel, ModelMixin):
    """Response model for delete operations"""

    message: str = Field(..., description="Confirmation message")
    thread_id: str = Field(..., description="Deleted thread ID")
# Register model labels for internationalization
# Each call registers (model name, model label, per-field labels) in EN/FR.
register_model_labels(
    "MessageItem",
    {"en": "Message Item", "fr": "Élément de message"},
    {
        "role": {"en": "Role", "fr": "Rôle"},
        "content": {"en": "Content", "fr": "Contenu"},
        "timestamp": {"en": "Timestamp", "fr": "Horodatage"},
    },
)
register_model_labels(
    "ChatMessageRequest",
    {"en": "Chat Message Request", "fr": "Demande de message de chat"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "message": {"en": "Message", "fr": "Message"},
    },
)
register_model_labels(
    "ChatMessageResponse",
    {"en": "Chat Message Response", "fr": "Réponse du message de chat"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "messages": {"en": "Messages", "fr": "Messages"},
    },
)
register_model_labels(
    "ThreadSummary",
    {"en": "Thread Summary", "fr": "Résumé du fil"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "created_at": {"en": "Created At", "fr": "Créé le"},
        "last_message": {"en": "Last Message", "fr": "Dernier message"},
        "message_count": {"en": "Message Count", "fr": "Nombre de messages"},
    },
)
register_model_labels(
    "ThreadListResponse",
    {"en": "Thread List Response", "fr": "Réponse de liste de fils"},
    {
        "threads": {"en": "Threads", "fr": "Fils"},
    },
)
register_model_labels(
    "ThreadDetail",
    {"en": "Thread Detail", "fr": "Détail du fil"},
    {
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
        "created_at": {"en": "Created At", "fr": "Créé le"},
        "messages": {"en": "Messages", "fr": "Messages"},
    },
)
register_model_labels(
    "DeleteResponse",
    {"en": "Delete Response", "fr": "Réponse de suppression"},
    {
        "message": {"en": "Message", "fr": "Message"},
        "thread_id": {"en": "Thread ID", "fr": "ID du fil"},
    },
)

View file

@ -0,0 +1,24 @@
"""Tool for sending streaming status updates to users."""
from langchain_core.tools import tool
# NOTE: the docstring below doubles as the tool description sent to the LLM —
# editing it changes how/when the model decides to invoke this tool.
@tool
def send_streaming_message(message: str) -> str:
    """Send a streaming message to the user to provide updates during processing.

    Use this tool to send short status updates to the user while you are working
    on their request. This helps keep the user informed about what you are doing.

    Args:
        message: A short message describing what you are currently doing.
                 Examples: "Searching database for relevant information..."
                          "Analyzing search results..."
                          "Processing your request..."

    Returns:
        A confirmation that the message was sent.
    """
    # This tool doesn't actually do anything - it's just for the AI to signal
    # what it's doing to the frontend via the tool call mechanism
    # (stream_events() watches for its on_tool_start event).
    return f"Status update sent: {message}"

View file

@ -0,0 +1 @@
"""Domain logic for chatbot functionality."""

View file

@ -0,0 +1,301 @@
"""Chatbot domain logic with LangGraph integration."""
from dataclasses import dataclass
from typing import Annotated, AsyncIterator, Any
import logging
from pydantic import BaseModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
class ChatState(BaseModel):
    """Represents the state of a chat session."""

    # Conversation history; the `add_messages` reducer makes LangGraph append
    # (rather than replace) messages returned by graph nodes.
    messages: Annotated[list[BaseMessage], add_messages]
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
    """Map permission model names to LangChain ChatAnthropic models.

    Unknown names do not raise; they log a warning and fall back to the
    default Anthropic model.

    Args:
        model_name: The model name from permissions (e.g., "claude_4_5")

    Returns:
        Configured ChatAnthropic instance
    """
    # Anthropic API model ids use the "claude-<family>-<major>-<minor>" scheme
    # since Claude 4 (e.g. "claude-sonnet-4-5"). The previous value
    # "claude-4-5-sonnet" is not a valid id and would fail at request time.
    default_model = "claude-sonnet-4-5"
    model_mapping = {
        "claude_4_5": default_model,
        # Add more mappings as needed
    }
    anthropic_model = model_mapping.get(model_name)
    if not anthropic_model:
        logger.warning(
            f"Unknown model name '{model_name}', defaulting to {default_model}"
        )
        anthropic_model = default_model
    return ChatAnthropic(
        model=anthropic_model,
        api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
        temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
        max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
    )
@dataclass
class Chatbot:
    """Represents a chatbot with LangGraph integration.

    Use the async ``create()`` factory rather than the constructor so that
    ``app`` (the compiled LangGraph) is built and wired to memory and tools.
    """

    # model/memory/app are typed Any to avoid binding to LangChain/LangGraph
    # internals; see create() for the expected objects.
    model: Any
    memory: Any
    app: Any = None
    system_prompt: str = "You are a helpful assistant."
    context_window_size: int = 100000

    @classmethod
    async def create(
        cls,
        *,
        model: Any,
        memory: Any,
        system_prompt: str,
        tools: list,
        context_window_size: int = 100000,
    ) -> "Chatbot":
        """Factory method to create and configure a Chatbot instance.

        Args:
            model: The chat model to use.
            memory: The chat memory checkpointer to use.
            system_prompt: The system prompt to initialize the chatbot.
            tools: List of LangChain tools the chatbot can use.
            context_window_size: Maximum tokens for context window.

        Returns:
            A configured Chatbot instance.
        """
        instance = cls(
            model=model,
            memory=memory,
            system_prompt=system_prompt,
            context_window_size=context_window_size,
        )
        instance.app = instance._build_app(memory=memory, tools=tools)
        return instance

    def _build_app(
        self, *, memory: Any, tools: list
    ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
        """Builds the chatbot application workflow using LangGraph.

        The graph is: START -> agent -> (tools -> agent)* -> END, looping
        through the tool node while the model keeps emitting tool calls.

        Args:
            memory: The chat memory checkpointer to use.
            tools: The list of tools the chatbot can use.

        Returns:
            A compiled state graph representing the chatbot application.
        """
        llm_with_tools = self.model.bind_tools(tools=tools)

        def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
            """Selects a window of messages that fit within the context window size.

            Args:
                msgs: The list of messages to select from.

            Returns:
                A list of messages that fit within the context window size.
            """

            def approx_counter(items: list[BaseMessage]) -> int:
                """Approximate token counter for messages.

                Args:
                    items: List of messages to count tokens for.

                Returns:
                    Approximate number of tokens in the messages.
                """
                # NOTE(review): counts characters, not tokens — only safe if
                # context_window_size is sized with that proxy in mind.
                return sum(len(getattr(m, "content", "") or "") for m in items)

            return trim_messages(
                msgs,
                strategy="last",
                token_counter=approx_counter,
                max_tokens=self.context_window_size,
                start_on="human",
                end_on=("human", "tool"),
                include_system=True,
            )

        def agent_node(state: ChatState) -> dict:
            """Agent node for the chatbot workflow.

            Args:
                state: The current chat state.

            Returns:
                The updated chat state after processing.
            """
            # Select the message window to fit in context (trim if needed)
            window = select_window(state.messages)
            # Ensure the system prompt is present at the start
            if not window or not isinstance(window[0], SystemMessage):
                window = [SystemMessage(content=self.system_prompt)] + window
            # Call the LLM with tools
            response = llm_with_tools.invoke(window)
            # Return the new state
            return {"messages": [response]}

        def should_continue(state: ChatState) -> str:
            """Determines whether to continue the workflow or end it.

            This conditional edge is called after the agent node to decide
            whether to continue to the tools node (if the last message contains
            tool calls) or to end the workflow (if no tool calls are present).

            Args:
                state: The current chat state.

            Returns:
                The next node to transition to ("tools" or END).
            """
            # Get the last message
            last_message = state.messages[-1]
            # Check if the last message contains tool calls
            # If so, continue to the tools node; otherwise, end the workflow
            return "tools" if getattr(last_message, "tool_calls", None) else END

        # Compose the workflow
        workflow = StateGraph(ChatState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(tools=tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=memory)

    async def chat(
        self, *, message: str, chat_id: str = "default"
    ) -> list[BaseMessage]:
        """Processes a chat message and returns the chat history.

        Args:
            message: The user message to process.
            chat_id: The chat thread ID.

        Returns:
            The list of messages in the chat history.
        """
        # Set the right thread ID for memory
        config = {"configurable": {"thread_id": chat_id}}
        # Single-turn chat (non-streaming)
        result = await self.app.ainvoke(
            {"messages": [HumanMessage(content=message)]}, config=config
        )
        # Extract and return the messages from the result
        return result["messages"]

    async def stream_events(
        self, *, message: str, chat_id: str = "default"
    ) -> AsyncIterator[dict]:
        """Stream UI-focused events using astream_events v2.

        Args:
            message: The user message to process.
            chat_id: Logical thread identifier; forwarded in the runnable config so
                memory and tools are scoped per thread.

        Yields:
            dict: One of:
                - ``{"type": "status", "label": str}`` for short progress updates.
                - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
                  where ``chat_history`` only includes ``user``/``assistant`` roles.
                - ``{"type": "error", "message": str}`` if an exception occurs.
        """
        # Thread-aware config for LangGraph/LangChain
        config = {"configurable": {"thread_id": chat_id}}

        def _is_root(ev: dict) -> bool:
            """Return True if the event is from the root run (v2: empty parent_ids)."""
            return not ev.get("parent_ids")

        try:
            async for event in self.app.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2",
            ):
                etype = event.get("event")
                ename = event.get("name") or ""
                edata = event.get("data") or {}
                # Stream human-readable progress via the special send_streaming_message tool
                if etype == "on_tool_start" and ename == "send_streaming_message":
                    tool_in = edata.get("input") or {}
                    msg = tool_in.get("message")
                    if isinstance(msg, str) and msg.strip():
                        yield {"type": "status", "label": msg.strip()}
                    continue
                # Emit the final payload when the root run finishes
                if etype == "on_chain_end" and _is_root(event):
                    output_obj = edata.get("output")
                    # Extract message list from the graph's final output
                    final_msgs = ChatStreamingHelper.extract_messages_from_output(
                        output_obj=output_obj
                    )
                    # Normalize for the frontend (only user/assistant with text content)
                    chat_history_payload: list[dict] = []
                    for m in final_msgs:
                        if isinstance(m, BaseMessage):
                            d = ChatStreamingHelper.message_to_dict(msg=m)
                        elif isinstance(m, dict):
                            d = ChatStreamingHelper.dict_message_to_dict(obj=m)
                        else:
                            continue
                        if d.get("role") in ("user", "assistant") and d.get("content"):
                            chat_history_payload.append(d)
                    yield {
                        "type": "final",
                        "response": {
                            "thread": chat_id,
                            "chat_history": chat_history_payload,
                        },
                    }
                    return
        except Exception as exc:
            # Emit a single error envelope and end the stream
            logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
            yield {"type": "error", "message": f"Error processing request: {exc}"}

View file

@ -0,0 +1,239 @@
"""Streaming helper utilities for chat message processing and normalization."""
from typing import Any, Dict, List, Literal, Mapping, Optional
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
Role = Literal["user", "assistant", "system", "tool"]
class ChatStreamingHelper:
"""Pure helper methods for streaming and message normalization.
This class provides static utility methods for converting between different
message formats, extracting content, and normalizing message structures
for streaming chat applications.
"""
@staticmethod
def role_from_message(*, msg: BaseMessage) -> Role:
"""Extract the role from a BaseMessage instance.
Args:
msg: The BaseMessage instance to extract the role from.
Returns:
The role as a string literal: "user", "assistant", "system", or "tool".
Defaults to "assistant" if the message type is not recognized.
Examples:
>>> from langchain_core.messages import HumanMessage
>>> msg = HumanMessage(content="Hello")
>>> ChatStreamingHelper.role_from_message(msg=msg)
'user'
"""
if isinstance(msg, HumanMessage):
return "user"
if isinstance(msg, AIMessage):
return "assistant"
if isinstance(msg, SystemMessage):
return "system"
if isinstance(msg, ToolMessage):
return "tool"
return getattr(msg, "role", "assistant")
    @staticmethod
    def flatten_content(*, content: Any) -> str:
        """Convert complex content structures to plain text.

        This method handles various content formats including strings, lists of
        content parts, and dictionaries with text fields. It's designed to
        normalize content from different message sources into a consistent
        plain text format.

        Args:
            content: The content to flatten. Can be:
                - str: Returned as-is after stripping whitespace
                - list: Each item processed and joined with newlines
                - dict: Text extracted from "text" or "content" fields
                - None: Returns empty string
                - Any other type: Converted to string

        Returns:
            The flattened content as a plain text string with whitespace stripped.

        Examples:
            >>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Hello\\nworld'
            >>> content = {"text": "Simple message"}
            >>> ChatStreamingHelper.flatten_content(content=content)
            'Simple message'
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts: List[str] = []
            for part in content:
                if isinstance(part, dict):
                    if "text" in part and isinstance(part["text"], str):
                        parts.append(part["text"])
                    elif part.get("type") == "text" and isinstance(
                        part.get("text"), str
                    ):
                        parts.append(part["text"])
                    elif "content" in part and isinstance(part["content"], str):
                        parts.append(part["content"])
                    else:
                        # Fallback for unknown dictionary structures
                        val = part.get("value")
                        if isinstance(val, str):
                            parts.append(val)
                        else:
                            parts.append(str(part))
                # NOTE(review): non-dict list items (e.g. plain strings) are
                # not appended here — confirm this is intended.
            return "\n".join(p.strip() for p in parts if p is not None)
        if isinstance(content, dict):
            if "text" in content and isinstance(content["text"], str):
                return content["text"].strip()
            if "content" in content and isinstance(content["content"], str):
                return content["content"].strip()
        return str(content).strip()
@staticmethod
def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
"""Convert a BaseMessage instance to a dictionary for streaming output.
This method normalizes BaseMessage instances into a consistent dictionary
format suitable for JSON serialization and streaming to clients.
Args:
msg: The BaseMessage instance to convert.
Returns:
A dictionary containing:
- "role": The message role (user, assistant, system, tool)
- "content": The flattened message content as plain text
- "tool_calls": Tool calls if present (optional)
- "name": Message name if present (optional)
Examples:
>>> from langchain_core.messages import HumanMessage
>>> msg = HumanMessage(content="Hello there")
>>> result = ChatStreamingHelper.message_to_dict(msg=msg)
>>> result["role"]
'user'
>>> result["content"]
'Hello there'
"""
payload: Dict[str, Any] = {
"role": ChatStreamingHelper.role_from_message(msg=msg),
"content": ChatStreamingHelper.flatten_content(
content=getattr(msg, "content", "")
),
}
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
payload["tool_calls"] = tool_calls
name = getattr(msg, "name", None)
if name:
payload["name"] = name
return payload
@staticmethod
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
"""Convert a dictionary-shaped message to a normalized dictionary.
This method handles messages that come from serialized state and are
represented as dictionaries rather than BaseMessage instances. It
normalizes various dictionary formats into a consistent structure.
Args:
obj: The dictionary-shaped message to convert. Expected to contain
fields like "role", "type", "content", "text", etc.
Returns:
A normalized dictionary containing:
- "role": The message role (user, assistant, system, tool)
- "content": The flattened message content as plain text
- "tool_calls": Tool calls if present (optional)
- "name": Message name if present (optional)
Examples:
>>> obj = {"type": "human", "content": "Hello"}
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
>>> result["role"]
'user'
>>> result["content"]
'Hello'
"""
role: Optional[str] = obj.get("role")
if not role:
# Handle alternative type field mappings
typ = obj.get("type")
if typ in ("human", "user"):
role = "user"
elif typ in ("ai", "assistant"):
role = "assistant"
elif typ in ("system",):
role = "system"
elif typ in ("tool", "function"):
role = "tool"
content = obj.get("content")
if content is None and "text" in obj:
content = obj["text"]
out: Dict[str, Any] = {
"role": role or "assistant",
"content": ChatStreamingHelper.flatten_content(content=content),
}
if "tool_calls" in obj:
out["tool_calls"] = obj["tool_calls"]
if obj.get("name"):
out["name"] = obj["name"]
return out
@staticmethod
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
"""Extract messages from LangGraph output objects.
This method handles various output formats from LangGraph execution,
extracting the messages list from different possible structures.
Args:
output_obj: The output object from LangGraph execution. Can be:
- An object with a "messages" attribute
- A dictionary with a "messages" key
- Any other object (returns empty list)
Returns:
A list of extracted messages, or an empty list if no messages
are found or if the output object is None.
Examples:
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
>>> len(messages)
1
"""
if output_obj is None:
return []
# Try to parse dicts first
if isinstance(output_obj, dict):
msgs = output_obj.get("messages")
return msgs if isinstance(msgs, list) else []
# Then try to get messages attribute
msgs = getattr(output_obj, "messages", None)
return msgs if isinstance(msgs, list) else []

View file

@ -0,0 +1,215 @@
"""Service layer for chatbot functionality."""
import json
import logging
from typing import AsyncIterator, List
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
from modules.features.chatBot.utils.checkpointer import get_checkpointer
from modules.features.chatBot.utils.toolRegistry import get_registry
from modules.features.chatBot.utils import permissions
from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse
from modules.datamodels.datamodelUam import User
from langchain_core.messages import HumanMessage, AIMessage
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
async def post_message(
    *,
    thread_id: str,
    message: str,
    user: User,
) -> ChatMessageResponse:
    """Post a chat message to the chatbot and return the full history.

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        user: The current user.

    Returns:
        The response containing the full chat message history and thread ID.

    Raises:
        ValueError: If the user has no permission to use any chatbot tools.
    """
    logger.info(f"User {user.id} posted message to thread {thread_id}")

    # Resolve permissions first; bail out early when nothing is allowed.
    allowed_tool_ids = permissions.get_chatbot_tools(user_id=user.id)
    if not allowed_tool_ids:
        raise ValueError("User does not have permission to use any chatbot tools")
    selected_model = permissions.get_chatbot_model(user_id=user.id)
    prompt = permissions.get_system_prompt(user_id=user.id)

    # Resolve tool instances and runtime configuration.
    tool_instances = get_registry().get_tool_instances(tool_ids=allowed_tool_ids)
    window_tokens = int(APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000))

    chatbot = await Chatbot.create(
        model=get_langchain_model(model_name=selected_model),
        memory=get_checkpointer(),
        system_prompt=prompt,
        tools=tool_instances,
        context_window_size=window_tokens,
    )
    raw_history = await chatbot.chat(message=message, chat_id=thread_id)

    # Convert the LangChain message objects to the API response format.
    history: List[MessageItem] = []
    for entry in raw_history:
        if isinstance(entry, HumanMessage):
            speaker = "user"
        elif isinstance(entry, AIMessage):
            speaker = "assistant"
        else:
            # Other message types (tool/system) are not user-visible.
            continue
        # Structured (non-string) content, such as tool calls, is skipped.
        if not isinstance(entry.content, str):
            continue
        history.append(
            MessageItem(
                role=speaker,
                content=entry.content.strip(),
                timestamp=0.0,  # TODO: Add proper timestamp handling
            )
        )
    return ChatMessageResponse(thread_id=thread_id, messages=history)
async def post_message_stream(
    *,
    thread_id: str,
    message: str,
    user: User,
) -> AsyncIterator[str]:
    """Post a chat message to the chatbot and stream progress updates (SSE).

    Args:
        thread_id: The unique identifier for the chat thread.
        message: The content of the chat message.
        user: The current user.

    Yields:
        Server-Sent Events formatted strings containing status updates and final response.

    Notes:
        Unlike ``post_message``, permission failures and internal errors do not
        raise; they are reported to the client as ``{"type": "error"}`` SSE
        events so the stream always terminates gracefully.
    """
    logger.info(f"User {user.id} streaming message to thread {thread_id}")
    try:
        # Get user permissions
        tool_ids = permissions.get_chatbot_tools(user_id=user.id)
        if not tool_ids:
            # Emit the failure as an SSE event rather than raising, so the
            # client still receives a well-formed stream.
            yield (
                "data: "
                + json.dumps(
                    {
                        "type": "error",
                        "message": "User does not have permission to use any chatbot tools",
                    }
                )
                + "\n\n"
            )
            return
        model_name = permissions.get_chatbot_model(user_id=user.id)
        system_prompt = permissions.get_system_prompt(user_id=user.id)
        # Get tools from registry
        registry = get_registry()
        tools = registry.get_tool_instances(tool_ids=tool_ids)
        # Get model and checkpointer
        model = get_langchain_model(model_name=model_name)
        checkpointer = get_checkpointer()
        # Get context window size from config
        context_window_size = int(
            APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
        )
        # Create chatbot instance
        chatbot = await Chatbot.create(
            model=model,
            memory=checkpointer,
            system_prompt=system_prompt,
            tools=tools,
            context_window_size=context_window_size,
        )
        # Stream events from chatbot
        async for event in chatbot.stream_events(message=message, chat_id=thread_id):
            etype = event.get("type")
            # Forward status updates
            if etype == "status":
                yield f"data: {json.dumps({'type': 'status', 'label': event.get('label')})}\n\n"
                continue
            # Forward final response
            if etype == "final":
                response_from_event = event.get("response") or {}
                # Use the chat history from the final event (already normalized by stream_events)
                chat_history_payload = response_from_event.get("chat_history", [])
                if isinstance(chat_history_payload, list):
                    # Convert to MessageItem format
                    items: List[MessageItem] = []
                    for it in chat_history_payload:
                        role = it.get("role")
                        content = it.get("content", "")
                        # Only user/assistant turns with non-empty text are
                        # surfaced; tool/system entries are dropped here.
                        if role in ("user", "assistant") and content:
                            items.append(
                                MessageItem(
                                    role=role,
                                    content=content,
                                    timestamp=0.0,  # TODO: Add proper timestamp handling
                                )
                            )
                    response = ChatMessageResponse(thread_id=thread_id, messages=items)
                    # Yield the final response and exit
                    yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n"
                    return
                else:
                    # Unexpected payload format - log warning and return empty history
                    logger.warning(
                        f"Unexpected chat_history format in final event: {type(chat_history_payload)}"
                    )
                    response = ChatMessageResponse(thread_id=thread_id, messages=[])
                    yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n"
                    return
            # Forward error events
            if etype == "error":
                # NOTE(review): the event is forwarded verbatim — assumes
                # stream_events only emits JSON-serializable dicts; confirm.
                yield f"data: {json.dumps(event)}\n\n"
                return
    except Exception as e:
        logger.error(f"Error in streaming chat: {str(e)}", exc_info=True)
        # Generic client-facing message on purpose: internal details stay in
        # the server log.
        yield (
            "data: "
            + json.dumps(
                {
                    "type": "error",
                    "message": "An error occurred while processing your request.",
                }
            )
            + "\n\n"
        )

View file

@ -0,0 +1,95 @@
"""PostgreSQL checkpointer utilities for LangGraph memory."""
import logging
from typing import Optional

from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool

from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
# Global checkpointer instance
_checkpointer_instance: Optional[PostgresSaver] = None
_connection_pool: Optional[AsyncConnectionPool] = None
async def initialize_checkpointer() -> None:
    """Initialize the PostgreSQL checkpointer for LangGraph.

    This should be called during application startup. Creates an async
    connection pool and an ``AsyncPostgresSaver`` bound to it.

    Raises:
        ValueError: If the database password is not configured.
        Exception: Propagates any error raised while opening the pool or
            creating the checkpoint tables.
    """
    global _checkpointer_instance, _connection_pool

    if _checkpointer_instance is not None:
        logger.info("Checkpointer already initialized")
        return

    try:
        # Get database configuration from environment
        host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
        database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
        user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
        password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
        port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
        # Fail fast with a clear message instead of embedding "None" in the DSN.
        if not password:
            raise ValueError(
                "LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET is not configured"
            )

        # Build connection string
        connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}"

        # open=False avoids psycopg_pool's deprecated implicit open in the
        # constructor; the pool is opened explicitly below. autocommit and
        # prepare_threshold follow the LangGraph checkpointer requirements
        # (setup() runs migrations that need autocommit connections).
        _connection_pool = AsyncConnectionPool(
            conninfo=connection_string,
            min_size=2,
            max_size=10,
            open=False,
            kwargs={"autocommit": True, "prepare_threshold": 0},
        )
        await _connection_pool.open()

        # Bug fix: the synchronous PostgresSaver cannot be driven by an
        # AsyncConnectionPool and its setup() is not awaitable; the async
        # variant must be used in this async lifecycle.
        _checkpointer_instance = AsyncPostgresSaver(_connection_pool)

        # Setup the checkpointer (creates tables if needed)
        await _checkpointer_instance.setup()
        logger.info("PostgreSQL checkpointer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
        raise
async def close_checkpointer() -> None:
    """Close the checkpointer and its connection pool.

    Called during application shutdown; safe to call even when the
    checkpointer was never initialized. Close errors are logged, and the
    module-level references are always cleared.
    """
    global _checkpointer_instance, _connection_pool

    pool = _connection_pool
    if pool is not None:
        try:
            await pool.close()
            logger.info("PostgreSQL checkpointer connection pool closed")
        except Exception as e:
            logger.error(f"Error closing checkpointer connection pool: {str(e)}")

    # Reset the globals regardless of whether closing succeeded.
    _checkpointer_instance = None
    _connection_pool = None
def get_checkpointer() -> PostgresSaver:
    """Return the global PostgreSQL checkpointer instance.

    Returns:
        The checkpointer created by ``initialize_checkpointer``.

    Raises:
        RuntimeError: If the checkpointer has not been initialized yet.
    """
    instance = _checkpointer_instance
    if instance is None:
        raise RuntimeError(
            "PostgreSQL checkpointer not initialized. "
            "Call initialize_checkpointer() during application startup."
        )
    return instance

View file

@ -12,36 +12,15 @@ from modules.features.chatBot.utils.toolRegistry import get_registry
# TODO: Replace these mock implementations with actual database queries
def get_allowed_tools(*, user_id: str) -> list[str]:
"""Get list of tool IDs that a user is allowed to use.
This is a mock implementation that returns all available tools
regardless of user_id. In production, this will query the database
for user-specific permissions.
Args:
user_id: The unique identifier of the user
Returns:
List of tool IDs (e.g., ["shared.tavily_search", "customer.query_althaus_database"])
"""
def get_chatbot_tools(*, user_id: str) -> list[str]:
"""Get list of tool IDs that the chatbot can use for a given user."""
registry = get_registry()
return registry.list_tool_ids()
def get_allowed_models(*, user_id: str) -> list[str]:
"""Get list of AI models that a user is allowed to use.
This is a mock implementation that returns a fixed list of models.
In production, this will query the database for user-specific model permissions.
Args:
user_id: The unique identifier of the user
Returns:
List of model identifiers (e.g., ["gpt-5", "claude-4-5"])
"""
return ["gpt-5", "claude-4-5"]
def get_chatbot_model(*, user_id: str) -> str:
    """Return the model identifier the chatbot should use for this user.

    Mock implementation: every user currently gets the same model,
    regardless of ``user_id``.
    """
    # NOTE(review): hard-coded until per-user model permissions are stored
    # in the database.
    return "claude_4_5"
def get_system_prompt(*, user_id: str) -> str:

View file

@ -1,13 +1,23 @@
from pydantic import BaseModel, Field
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.requests import Request
from fastapi.responses import StreamingResponse
from typing import Any, Dict, List, Optional
from datetime import datetime
import logging
import uuid
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelChatbot import (
ChatMessageRequest,
MessageItem,
ChatMessageResponse,
ThreadSummary,
ThreadListResponse,
ThreadDetail,
DeleteResponse,
)
from modules.security.auth import getCurrentUser, limiter
from modules.features.chatBot import service as chat_service
logger = logging.getLogger(__name__)
@ -17,68 +27,53 @@ router = APIRouter(
responses={404: {"description": "Not found"}},
)
# --- Pydantic models for requests and responses ---
class ChatMessageRequest(BaseModel):
"""Request model for posting a chat message"""
thread_id: Optional[str] = Field(
None, description="Thread ID (creates new thread if not provided)"
)
message: str = Field(..., description="User message content")
class MessageItem(BaseModel):
"""Individual message in a thread"""
role: str = Field(..., description="Message role (user or assistant)")
content: str = Field(..., description="Message content")
timestamp: float = Field(..., description="Message timestamp (Unix timestamp)")
class ChatMessageResponse(BaseModel):
"""Response model for posting a chat message"""
thread_id: str = Field(..., description="Thread ID")
messages: List[MessageItem] = Field(..., description="All messages in thread")
class ThreadSummary(BaseModel):
"""Summary of a chat thread for list view"""
thread_id: str = Field(..., description="Thread ID")
created_at: float = Field(..., description="Thread creation timestamp")
last_message: str = Field(..., description="Last message content")
message_count: int = Field(..., description="Total number of messages")
class ThreadListResponse(BaseModel):
"""Response model for listing all threads"""
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
class ThreadDetail(BaseModel):
"""Detailed view of a single thread"""
thread_id: str = Field(..., description="Thread ID")
created_at: float = Field(..., description="Thread creation timestamp")
messages: List[MessageItem] = Field(
..., description="All messages in chronological order"
)
class DeleteResponse(BaseModel):
"""Response model for delete operations"""
message: str = Field(..., description="Confirmation message")
thread_id: str = Field(..., description="Deleted thread ID")
# --- Actual endpoints for chatbot ---
@router.post("/message/stream")
@limiter.limit("30/minute")
async def post_chat_message_stream(
    *,
    request: Request,
    message_request: ChatMessageRequest,
    currentUser: User = Depends(getCurrentUser),
) -> StreamingResponse:
    """
    Post a message to a chat thread with streaming progress updates.

    A new thread is created when no thread_id is supplied. The response is a
    Server-Sent Events (SSE) stream carrying status updates followed by the
    final chat response.
    """
    try:
        # Reuse the caller's thread or mint a fresh one.
        thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
        logger.info(
            f"User {currentUser.id} posted streaming message to thread {thread_id}"
        )
        event_stream = chat_service.post_message_stream(
            thread_id=thread_id,
            message=message_request.message,
            user=currentUser,
        )
        # Standard SSE headers: disable caching and keep the connection open.
        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
        return StreamingResponse(
            event_stream,
            media_type="text/event-stream",
            headers=sse_headers,
        )
    except Exception as e:
        logger.error(f"Error posting chat message: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to post message: {str(e)}",
        )
@router.post("/message", response_model=ChatMessageResponse)
@limiter.limit("30/minute")
async def post_chat_message(
@ -88,35 +83,31 @@ async def post_chat_message(
currentUser: User = Depends(getCurrentUser),
) -> ChatMessageResponse:
"""
Post a message to a chat thread and get assistant response.
Post a message to a chat thread and get assistant response (non-streaming).
Creates a new thread if thread_id is not provided.
This endpoint will later be connected to LangGraph's checkpointer.
For streaming updates, use the /message/stream endpoint instead.
"""
try:
# Generate or use existing thread_id
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
# Get current timestamp
current_time = datetime.now().timestamp()
# Create dummy message history
# In production, this will fetch from LangGraph's checkpointer
messages = [
MessageItem(
role="user", content=message_request.message, timestamp=current_time
),
MessageItem(
role="assistant",
content=f"Echo: {message_request.message} (This is a dummy response. LangGraph integration pending.)",
timestamp=current_time + 0.5,
),
]
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
return ChatMessageResponse(thread_id=thread_id, messages=messages)
response = await chat_service.post_message(
thread_id=thread_id,
message=message_request.message,
user=currentUser,
)
return response
except ValueError as e:
logger.error(f"Permission error: {str(e)}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except Exception as e:
logger.error(f"Error posting chat message: {str(e)}")
raise HTTPException(

View file

@ -3,7 +3,7 @@ fastapi==0.104.1
websockets==12.0
uvicorn==0.23.2
python-multipart==0.0.6
httpx==0.25.0
httpx>=0.25.2
pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility
email-validator==2.0.0 # Required by Pydantic for email validation
slowapi==0.1.8 # For rate limiting
@ -113,3 +113,7 @@ psycopg2-binary==2.9.9
langchain==0.3.27
langgraph==0.6.8
langchain-core==0.3.77
langchain-anthropic==0.3.1 # For Claude models
psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer)
psycopg-pool==3.2.1 # Connection pooling for PostgreSQL
langgraph-checkpoint-postgres==2.0.24