From 4bfeded9d0f1760e953ef4268e1139866ab168c9 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 3 Oct 2025 16:48:33 +0200 Subject: [PATCH] feat: chatbot w/ streaming basics --- app.py | 17 + modules/datamodels/datamodelChat.py | 268 ++++++++++++---- modules/datamodels/datamodelChatbot.py | 130 ++++++++ .../sharedTools/toolStreamingStatus.py | 24 ++ modules/features/chatBot/domain/__init__.py | 1 + modules/features/chatBot/domain/chatbot.py | 301 ++++++++++++++++++ .../chatBot/domain/streaming_helper.py | 239 ++++++++++++++ modules/features/chatBot/service.py | 215 +++++++++++++ .../features/chatBot/utils/checkpointer.py | 95 ++++++ modules/features/chatBot/utils/permissions.py | 31 +- modules/routes/routeChatbot.py | 147 ++++----- requirements.txt | 6 +- 12 files changed, 1315 insertions(+), 159 deletions(-) create mode 100644 modules/datamodels/datamodelChatbot.py create mode 100644 modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py create mode 100644 modules/features/chatBot/domain/__init__.py create mode 100644 modules/features/chatBot/domain/chatbot.py create mode 100644 modules/features/chatBot/domain/streaming_helper.py create mode 100644 modules/features/chatBot/service.py create mode 100644 modules/features/chatBot/utils/checkpointer.py diff --git a/app.py b/app.py index ed4e7214..23eeb645 100644 --- a/app.py +++ b/app.py @@ -240,9 +240,26 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL") @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Application is starting up") + + # Initialize LangGraph checkpointer + from modules.features.chatBot.utils.checkpointer import ( + initialize_checkpointer, + close_checkpointer, + ) + + try: + await initialize_checkpointer() + logger.info("LangGraph checkpointer initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}") + # Continue startup even if checkpointer fails to initialize + eventManager.start() 
yield + + # Cleanup eventManager.stop() + await close_checkpointer() logger.info("Application has been shut down") diff --git a/modules/datamodels/datamodelChat.py b/modules/datamodels/datamodelChat.py index a1640b5d..62fa691a 100644 --- a/modules/datamodels/datamodelChat.py +++ b/modules/datamodels/datamodelChat.py @@ -8,10 +8,18 @@ import uuid class ChatStat(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") - workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)") - messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)") - processingTime: Optional[float] = Field(None, description="Processing time in seconds") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) + workflowId: Optional[str] = Field( + None, description="Foreign key to workflow (for workflow stats)" + ) + messageId: Optional[str] = Field( + None, description="Foreign key to message (for message stats)" + ) + processingTime: Optional[float] = Field( + None, description="Processing time in seconds" + ) tokenCount: Optional[int] = Field(None, description="Number of tokens processed") bytesSent: Optional[int] = Field(None, description="Number of bytes sent") bytesReceived: Optional[int] = Field(None, description="Number of bytes received") @@ -37,14 +45,23 @@ register_model_labels( class ChatLog(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) workflowId: str = Field(description="Foreign key to workflow") message: str = Field(description="Log message") type: str = Field(description="Log type (info, warning, error, etc.)") - timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)") + 
timestamp: float = Field( + default_factory=get_utc_timestamp, + description="When the log entry was created (UTC timestamp in seconds)", + ) status: Optional[str] = Field(None, description="Status of the log entry") - progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)") - performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics") + progress: Optional[float] = Field( + None, description="Progress indicator (0.0 to 1.0)" + ) + performance: Optional[Dict[str, Any]] = Field( + None, description="Performance metrics" + ) register_model_labels( @@ -64,7 +81,9 @@ register_model_labels( class ChatDocument(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) messageId: str = Field(description="Foreign key to message") fileId: str = Field(description="Foreign key to file") fileName: str = Field(description="Name of the file") @@ -73,7 +92,9 @@ class ChatDocument(BaseModel, ModelMixin): roundNumber: Optional[int] = Field(None, description="Round number in workflow") taskNumber: Optional[int] = Field(None, description="Task number within round") actionNumber: Optional[int] = Field(None, description="Action number within task") - actionId: Optional[str] = Field(None, description="ID of the action that created this document") + actionId: Optional[str] = Field( + None, description="ID of the action that created this document" + ) register_model_labels( @@ -96,13 +117,19 @@ register_model_labels( class ContentMetadata(BaseModel, ModelMixin): size: int = Field(description="Content size in bytes") - pages: Optional[int] = Field(None, description="Number of pages for multi-page content") + pages: Optional[int] = Field( + None, description="Number of pages for multi-page content" + ) error: Optional[str] = Field(None, description="Processing error if any") width: 
Optional[int] = Field(None, description="Width in pixels for images/videos") - height: Optional[int] = Field(None, description="Height in pixels for images/videos") + height: Optional[int] = Field( + None, description="Height in pixels for images/videos" + ) colorMode: Optional[str] = Field(None, description="Color mode") fps: Optional[float] = Field(None, description="Frames per second for videos") - durationSec: Optional[float] = Field(None, description="Duration in seconds for media") + durationSec: Optional[float] = Field( + None, description="Duration in seconds for media" + ) mimeType: str = Field(description="MIME type of the content") base64Encoded: bool = Field(description="Whether the data is base64 encoded") @@ -144,7 +171,9 @@ register_model_labels( class ExtractedContent(BaseModel, ModelMixin): id: str = Field(description="Reference to source ChatDocument") - contents: List[ContentItem] = Field(default_factory=list, description="List of content items") + contents: List[ContentItem] = Field( + default_factory=list, description="List of content items" + ) register_model_labels( @@ -156,27 +185,53 @@ register_model_labels( }, ) + class ChatMessage(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), description="Primary key" + ) workflowId: str = Field(description="Foreign key to workflow") - parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading") - documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents") - documentsLabel: Optional[str] = Field(None, description="Label for the set of documents") + parentMessageId: Optional[str] = Field( + None, description="Parent message ID for threading" + ) + documents: List[ChatDocument] = Field( + default_factory=list, description="Associated documents" + ) + documentsLabel: Optional[str] = Field( + None, 
description="Label for the set of documents" + ) message: Optional[str] = Field(None, description="Message content") role: str = Field(description="Role of the message sender") status: str = Field(description="Status of the message (first, step, last)") - sequenceNr: int = Field(description="Sequence number of the message (set automatically)") - publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)") + sequenceNr: int = Field( + description="Sequence number of the message (set automatically)" + ) + publishedAt: float = Field( + default_factory=get_utc_timestamp, + description="When the message was published (UTC timestamp in seconds)", + ) stats: Optional[ChatStat] = Field(None, description="Statistics for this message") - success: Optional[bool] = Field(None, description="Whether the message processing was successful") - actionId: Optional[str] = Field(None, description="ID of the action that produced this message") - actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message") - actionName: Optional[str] = Field(None, description="Name of the action that produced this message") + success: Optional[bool] = Field( + None, description="Whether the message processing was successful" + ) + actionId: Optional[str] = Field( + None, description="ID of the action that produced this message" + ) + actionMethod: Optional[str] = Field( + None, description="Method of the action that produced this message" + ) + actionName: Optional[str] = Field( + None, description="Name of the action that produced this message" + ) roundNumber: Optional[int] = Field(None, description="Round number in workflow") taskNumber: Optional[int] = Field(None, description="Task number within round") actionNumber: Optional[int] = Field(None, description="Action number within task") - taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, 
success, fail, retry") - actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail") + taskProgress: Optional[str] = Field( + None, description="Task progress status: pending, running, success, fail, retry" + ) + actionProgress: Optional[str] = Field( + None, description="Action progress status: pending, running, success, fail" + ) register_model_labels( @@ -208,31 +263,135 @@ register_model_labels( class ChatWorkflow(BaseModel, ModelMixin): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False) - mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False) - status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[ - {"value": "running", "label": {"en": "Running", "fr": "En cours"}}, - {"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}}, - {"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}}, - {"value": "error", "label": {"en": "Error", "fr": "Erreur"}}, - ]) - name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True) - currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False) - currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False) - currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False) - totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False) - 
totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False) - lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False) - startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False) - logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False) - messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False) - stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False) - tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False) - workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[ - {"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}}, - {"value": "React", "label": {"en": "React", "fr": "Réactif"}}, - ]) - maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False) + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Primary key", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + mandateId: str = Field( + description="ID of the mandate this workflow belongs to", + frontend_type="text", 
+ frontend_readonly=True, + frontend_required=False, + ) + status: str = Field( + description="Current status of the workflow", + frontend_type="select", + frontend_readonly=False, + frontend_required=False, + frontend_options=[ + {"value": "running", "label": {"en": "Running", "fr": "En cours"}}, + {"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}}, + {"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}}, + {"value": "error", "label": {"en": "Error", "fr": "Erreur"}}, + ], + ) + name: Optional[str] = Field( + None, + description="Name of the workflow", + frontend_type="text", + frontend_readonly=False, + frontend_required=True, + ) + currentRound: int = Field( + description="Current round number", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + currentTask: int = Field( + default=0, + description="Current task number", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + currentAction: int = Field( + default=0, + description="Current action number", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + totalTasks: int = Field( + default=0, + description="Total number of tasks in the workflow", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + totalActions: int = Field( + default=0, + description="Total number of actions in the workflow", + frontend_type="integer", + frontend_readonly=True, + frontend_required=False, + ) + lastActivity: float = Field( + default_factory=get_utc_timestamp, + description="Timestamp of last activity (UTC timestamp in seconds)", + frontend_type="timestamp", + frontend_readonly=True, + frontend_required=False, + ) + startedAt: float = Field( + default_factory=get_utc_timestamp, + description="When the workflow started (UTC timestamp in seconds)", + frontend_type="timestamp", + frontend_readonly=True, + frontend_required=False, + ) + logs: List[ChatLog] = Field( + 
default_factory=list, + description="Workflow logs", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + messages: List[ChatMessage] = Field( + default_factory=list, + description="Messages in the workflow", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + stats: Optional[ChatStat] = Field( + None, + description="Workflow statistics", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + tasks: list = Field( + default_factory=list, + description="List of tasks in the workflow", + frontend_type="text", + frontend_readonly=True, + frontend_required=False, + ) + workflowMode: str = Field( + default="Actionplan", + description="Workflow mode selector", + frontend_type="select", + frontend_readonly=False, + frontend_required=False, + frontend_options=[ + { + "value": "Actionplan", + "label": {"en": "Action Plan", "fr": "Plan d'actions"}, + }, + {"value": "React", "label": {"en": "React", "fr": "Réactif"}}, + ], + ) + maxSteps: int = Field( + default=5, + description="Maximum number of iterations in react mode", + frontend_type="integer", + frontend_readonly=False, + frontend_required=False, + ) register_model_labels( @@ -278,7 +437,10 @@ register_model_labels( "completed_tasks": {"en": "Completed Tasks", "fr": "Tâches terminées"}, "total_tasks": {"en": "Total Tasks", "fr": "Total des tâches"}, "execution_time": {"en": "Execution Time", "fr": "Temps d'exécution"}, - "final_results_count": {"en": "Final Results Count", "fr": "Nombre de résultats finaux"}, + "final_results_count": { + "en": "Final Results Count", + "fr": "Nombre de résultats finaux", + }, "error": {"en": "Error", "fr": "Erreur"}, "phase": {"en": "Phase", "fr": "Phase"}, }, @@ -300,5 +462,3 @@ register_model_labels( "userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"}, }, ) - - diff --git a/modules/datamodels/datamodelChatbot.py b/modules/datamodels/datamodelChatbot.py new file mode 100644 index 
00000000..906757b7 --- /dev/null +++ b/modules/datamodels/datamodelChatbot.py @@ -0,0 +1,130 @@ +"""Chatbot API models for request/response handling.""" + +from typing import List, Optional +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import register_model_labels, ModelMixin + + +# Chatbot API Models +class MessageItem(BaseModel, ModelMixin): + """Individual message in a thread""" + + role: str = Field(..., description="Message role (user or assistant)") + content: str = Field(..., description="Message content") + timestamp: float = Field(..., description="Message timestamp (Unix timestamp)") + + +class ChatMessageRequest(BaseModel, ModelMixin): + """Request model for posting a chat message""" + + thread_id: Optional[str] = Field( + None, description="Thread ID (creates new thread if not provided)" + ) + message: str = Field(..., description="User message content") + + +class ChatMessageResponse(BaseModel, ModelMixin): + """Response model for posting a chat message""" + + thread_id: str = Field(..., description="Thread ID") + messages: List[MessageItem] = Field(..., description="All messages in thread") + + +class ThreadSummary(BaseModel, ModelMixin): + """Summary of a chat thread for list view""" + + thread_id: str = Field(..., description="Thread ID") + created_at: float = Field(..., description="Thread creation timestamp") + last_message: str = Field(..., description="Last message content") + message_count: int = Field(..., description="Total number of messages") + + +class ThreadListResponse(BaseModel, ModelMixin): + """Response model for listing all threads""" + + threads: List[ThreadSummary] = Field(..., description="List of thread summaries") + + +class ThreadDetail(BaseModel, ModelMixin): + """Detailed view of a single thread""" + + thread_id: str = Field(..., description="Thread ID") + created_at: float = Field(..., description="Thread creation timestamp") + messages: List[MessageItem] = Field( + ..., description="All messages 
in chronological order" + ) + + +class DeleteResponse(BaseModel, ModelMixin): + """Response model for delete operations""" + + message: str = Field(..., description="Confirmation message") + thread_id: str = Field(..., description="Deleted thread ID") + + +# Register model labels for internationalization +register_model_labels( + "MessageItem", + {"en": "Message Item", "fr": "Élément de message"}, + { + "role": {"en": "Role", "fr": "Rôle"}, + "content": {"en": "Content", "fr": "Contenu"}, + "timestamp": {"en": "Timestamp", "fr": "Horodatage"}, + }, +) + +register_model_labels( + "ChatMessageRequest", + {"en": "Chat Message Request", "fr": "Demande de message de chat"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "message": {"en": "Message", "fr": "Message"}, + }, +) + +register_model_labels( + "ChatMessageResponse", + {"en": "Chat Message Response", "fr": "Réponse du message de chat"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "messages": {"en": "Messages", "fr": "Messages"}, + }, +) + +register_model_labels( + "ThreadSummary", + {"en": "Thread Summary", "fr": "Résumé du fil"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "created_at": {"en": "Created At", "fr": "Créé le"}, + "last_message": {"en": "Last Message", "fr": "Dernier message"}, + "message_count": {"en": "Message Count", "fr": "Nombre de messages"}, + }, +) + +register_model_labels( + "ThreadListResponse", + {"en": "Thread List Response", "fr": "Réponse de liste de fils"}, + { + "threads": {"en": "Threads", "fr": "Fils"}, + }, +) + +register_model_labels( + "ThreadDetail", + {"en": "Thread Detail", "fr": "Détail du fil"}, + { + "thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + "created_at": {"en": "Created At", "fr": "Créé le"}, + "messages": {"en": "Messages", "fr": "Messages"}, + }, +) + +register_model_labels( + "DeleteResponse", + {"en": "Delete Response", "fr": "Réponse de suppression"}, + { + "message": {"en": "Message", "fr": "Message"}, + 
"thread_id": {"en": "Thread ID", "fr": "ID du fil"}, + }, +) diff --git a/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py b/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py new file mode 100644 index 00000000..f2587be8 --- /dev/null +++ b/modules/features/chatBot/chatbotTools/sharedTools/toolStreamingStatus.py @@ -0,0 +1,24 @@ +"""Tool for sending streaming status updates to users.""" + +from langchain_core.tools import tool + + +@tool +def send_streaming_message(message: str) -> str: + """Send a streaming message to the user to provide updates during processing. + + Use this tool to send short status updates to the user while you are working + on their request. This helps keep the user informed about what you are doing. + + Args: + message: A short message describing what you are currently doing. + Examples: "Searching database for relevant information..." + "Analyzing search results..." + "Processing your request..." + + Returns: + A confirmation that the message was sent. 
+ """ + # This tool doesn't actually do anything - it's just for the AI to signal + # what it's doing to the frontend via the tool call mechanism + return f"Status update sent: {message}" diff --git a/modules/features/chatBot/domain/__init__.py b/modules/features/chatBot/domain/__init__.py new file mode 100644 index 00000000..abd60dca --- /dev/null +++ b/modules/features/chatBot/domain/__init__.py @@ -0,0 +1 @@ +"""Domain logic for chatbot functionality.""" diff --git a/modules/features/chatBot/domain/chatbot.py b/modules/features/chatBot/domain/chatbot.py new file mode 100644 index 00000000..56129097 --- /dev/null +++ b/modules/features/chatBot/domain/chatbot.py @@ -0,0 +1,301 @@ +"""Chatbot domain logic with LangGraph integration.""" + +from dataclasses import dataclass +from typing import Annotated, AsyncIterator, Any +import logging + +from pydantic import BaseModel +from langchain_core.messages import ( + BaseMessage, + HumanMessage, + SystemMessage, + trim_messages, +) +from langgraph.graph.message import add_messages +from langgraph.graph import StateGraph, START, END +from langgraph.graph.state import CompiledStateGraph +from langgraph.prebuilt import ToolNode +from langchain_anthropic import ChatAnthropic + +from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper +from modules.features.chatBot.utils.toolRegistry import get_registry +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + + +class ChatState(BaseModel): + """Represents the state of a chat session.""" + + messages: Annotated[list[BaseMessage], add_messages] + + +def get_langchain_model(*, model_name: str) -> ChatAnthropic: + """Map permission model names to LangChain ChatAnthropic models. 
+ + Args: + model_name: The model name from permissions (e.g., "claude_4_5") + + Returns: + Configured ChatAnthropic instance + + Raises: + ValueError: If the model name is not supported + """ + # Model name mapping + model_mapping = { + "claude_4_5": "claude-sonnet-4-5", + # Add more mappings as needed + } + + anthropic_model = model_mapping.get(model_name) + if not anthropic_model: + logger.warning( + f"Unknown model name '{model_name}', defaulting to claude-sonnet-4-5" + ) + anthropic_model = "claude-sonnet-4-5" + + return ChatAnthropic( + model=anthropic_model, + api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"), + temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)), + max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)), + ) + + +@dataclass +class Chatbot: + """Represents a chatbot with LangGraph integration.""" + + model: Any + memory: Any + app: Any = None + system_prompt: str = "You are a helpful assistant." + context_window_size: int = 100000 + + @classmethod + async def create( + cls, + *, + model: Any, + memory: Any, + system_prompt: str, + tools: list, + context_window_size: int = 100000, + ) -> "Chatbot": + """Factory method to create and configure a Chatbot instance. + + Args: + model: The chat model to use. + memory: The chat memory checkpointer to use. + system_prompt: The system prompt to initialize the chatbot. + tools: List of LangChain tools the chatbot can use. + context_window_size: Maximum tokens for context window. + + Returns: + A configured Chatbot instance. + """ + instance = cls( + model=model, + memory=memory, + system_prompt=system_prompt, + context_window_size=context_window_size, + ) + instance.app = instance._build_app(memory=memory, tools=tools) + return instance + + def _build_app( + self, *, memory: Any, tools: list + ) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]: + """Builds the chatbot application workflow using LangGraph. 
+ + Args: + memory: The chat memory checkpointer to use. + tools: The list of tools the chatbot can use. + + Returns: + A compiled state graph representing the chatbot application. + """ + llm_with_tools = self.model.bind_tools(tools=tools) + + def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]: + """Selects a window of messages that fit within the context window size. + + Args: + msgs: The list of messages to select from. + + Returns: + A list of messages that fit within the context window size. + """ + + def approx_counter(items: list[BaseMessage]) -> int: + """Approximate token counter for messages. + + Args: + items: List of messages to count tokens for. + + Returns: + Approximate number of tokens in the messages. + """ + return sum(len(getattr(m, "content", "") or "") for m in items) + + return trim_messages( + msgs, + strategy="last", + token_counter=approx_counter, + max_tokens=self.context_window_size, + start_on="human", + end_on=("human", "tool"), + include_system=True, + ) + + def agent_node(state: ChatState) -> dict: + """Agent node for the chatbot workflow. + + Args: + state: The current chat state. + + Returns: + The updated chat state after processing. + """ + # Select the message window to fit in context (trim if needed) + window = select_window(state.messages) + + # Ensure the system prompt is present at the start + if not window or not isinstance(window[0], SystemMessage): + window = [SystemMessage(content=self.system_prompt)] + window + + # Call the LLM with tools + response = llm_with_tools.invoke(window) + + # Return the new state + return {"messages": [response]} + + def should_continue(state: ChatState) -> str: + """Determines whether to continue the workflow or end it. + + This conditional edge is called after the agent node to decide + whether to continue to the tools node (if the last message contains + tool calls) or to end the workflow (if no tool calls are present). + + Args: + state: The current chat state. 
+ + Returns: + The next node to transition to ("tools" or END). + """ + # Get the last message + last_message = state.messages[-1] + + # Check if the last message contains tool calls + # If so, continue to the tools node; otherwise, end the workflow + return "tools" if getattr(last_message, "tool_calls", None) else END + + # Compose the workflow + workflow = StateGraph(ChatState) + workflow.add_node("agent", agent_node) + workflow.add_node("tools", ToolNode(tools=tools)) + workflow.add_edge(START, "agent") + workflow.add_conditional_edges("agent", should_continue) + workflow.add_edge("tools", "agent") + return workflow.compile(checkpointer=memory) + + async def chat( + self, *, message: str, chat_id: str = "default" + ) -> list[BaseMessage]: + """Processes a chat message and returns the chat history. + + Args: + message: The user message to process. + chat_id: The chat thread ID. + + Returns: + The list of messages in the chat history. + """ + # Set the right thread ID for memory + config = {"configurable": {"thread_id": chat_id}} + + # Single-turn chat (non-streaming) + result = await self.app.ainvoke( + {"messages": [HumanMessage(content=message)]}, config=config + ) + + # Extract and return the messages from the result + return result["messages"] + + async def stream_events( + self, *, message: str, chat_id: str = "default" + ) -> AsyncIterator[dict]: + """Stream UI-focused events using astream_events v2. + + Args: + message: The user message to process. + chat_id: Logical thread identifier; forwarded in the runnable config so + memory and tools are scoped per thread. + + Yields: + dict: One of: + - ``{"type": "status", "label": str}`` for short progress updates. + - ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}`` + where ``chat_history`` only includes ``user``/``assistant`` roles. + - ``{"type": "error", "message": str}`` if an exception occurs. 
+ """ + # Thread-aware config for LangGraph/LangChain + config = {"configurable": {"thread_id": chat_id}} + + def _is_root(ev: dict) -> bool: + """Return True if the event is from the root run (v2: empty parent_ids).""" + return not ev.get("parent_ids") + + try: + async for event in self.app.astream_events( + {"messages": [HumanMessage(content=message)]}, + config=config, + version="v2", + ): + etype = event.get("event") + ename = event.get("name") or "" + edata = event.get("data") or {} + + # Stream human-readable progress via the special send_streaming_message tool + if etype == "on_tool_start" and ename == "send_streaming_message": + tool_in = edata.get("input") or {} + msg = tool_in.get("message") + if isinstance(msg, str) and msg.strip(): + yield {"type": "status", "label": msg.strip()} + continue + + # Emit the final payload when the root run finishes + if etype == "on_chain_end" and _is_root(event): + output_obj = edata.get("output") + + # Extract message list from the graph's final output + final_msgs = ChatStreamingHelper.extract_messages_from_output( + output_obj=output_obj + ) + + # Normalize for the frontend (only user/assistant with text content) + chat_history_payload: list[dict] = [] + for m in final_msgs: + if isinstance(m, BaseMessage): + d = ChatStreamingHelper.message_to_dict(msg=m) + elif isinstance(m, dict): + d = ChatStreamingHelper.dict_message_to_dict(obj=m) + else: + continue + if d.get("role") in ("user", "assistant") and d.get("content"): + chat_history_payload.append(d) + + yield { + "type": "final", + "response": { + "thread": chat_id, + "chat_history": chat_history_payload, + }, + } + return + + except Exception as exc: + # Emit a single error envelope and end the stream + logger.error(f"Error in stream_events: {str(exc)}", exc_info=True) + yield {"type": "error", "message": f"Error processing request: {exc}"} diff --git a/modules/features/chatBot/domain/streaming_helper.py b/modules/features/chatBot/domain/streaming_helper.py new file 
"""Streaming helper utilities for chat message processing and normalization."""

from typing import Any, Dict, List, Literal, Mapping, Optional

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

Role = Literal["user", "assistant", "system", "tool"]


class ChatStreamingHelper:
    """Pure helper methods for streaming and message normalization.

    Static utilities that convert between message representations
    (BaseMessage instances, serialized dict messages, LangGraph outputs)
    and the flat ``{"role", "content"}`` dicts streamed to the frontend.
    """

    @staticmethod
    def role_from_message(*, msg: BaseMessage) -> Role:
        """Map a BaseMessage subclass to its chat role.

        Args:
            msg: The BaseMessage instance to classify.

        Returns:
            "user", "assistant", "system" or "tool".  Unknown message types
            fall back to their ``role`` attribute, then to "assistant".

        Examples:
            >>> from langchain_core.messages import HumanMessage
            >>> ChatStreamingHelper.role_from_message(msg=HumanMessage(content="Hi"))
            'user'
        """
        if isinstance(msg, HumanMessage):
            return "user"
        if isinstance(msg, AIMessage):
            return "assistant"
        if isinstance(msg, SystemMessage):
            return "system"
        if isinstance(msg, ToolMessage):
            return "tool"
        return getattr(msg, "role", "assistant")

    @staticmethod
    def flatten_content(*, content: Any) -> str:
        """Convert complex content structures to plain text.

        Handles the content shapes LangChain produces: a plain string, a
        list of content parts (strings or dicts with "text"/"content"
        fields), or a dict with a "text"/"content" field.

        Args:
            content: The content to flatten:
                - str: returned stripped
                - list: each part flattened and joined with newlines
                - dict: text extracted from "text" or "content" fields
                - None: returns ""
                - anything else: stringified

        Returns:
            The flattened content as a whitespace-stripped plain string.

        Examples:
            >>> parts = [{"type": "text", "text": "Hello"}, "world"]
            >>> ChatStreamingHelper.flatten_content(content=parts)
            'Hello\\nworld'
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts: List[str] = []
            for part in content:
                if isinstance(part, str):
                    # BUGFIX: plain string parts are valid LangChain content
                    # blocks; they were previously dropped silently, making
                    # list-of-str content flatten to "".
                    parts.append(part)
                elif isinstance(part, dict):
                    if "text" in part and isinstance(part["text"], str):
                        parts.append(part["text"])
                    elif part.get("type") == "text" and isinstance(
                        part.get("text"), str
                    ):
                        parts.append(part["text"])
                    elif "content" in part and isinstance(part["content"], str):
                        parts.append(part["content"])
                    else:
                        # Fallback for unknown dictionary structures
                        val = part.get("value")
                        parts.append(val if isinstance(val, str) else str(part))
                else:
                    # Last-resort stringification for unknown part types
                    parts.append(str(part))
            return "\n".join(p.strip() for p in parts)
        if isinstance(content, dict):
            if "text" in content and isinstance(content["text"], str):
                return content["text"].strip()
            if "content" in content and isinstance(content["content"], str):
                return content["content"].strip()
        return str(content).strip()

    @staticmethod
    def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
        """Convert a BaseMessage to a JSON-serializable dict for streaming.

        Args:
            msg: The BaseMessage instance to convert.

        Returns:
            A dict with:
                - "role": the message role (user/assistant/system/tool)
                - "content": flattened plain-text content
                - "tool_calls": present only if the message carries tool calls
                - "name": present only if the message has a name
        """
        payload: Dict[str, Any] = {
            "role": ChatStreamingHelper.role_from_message(msg=msg),
            "content": ChatStreamingHelper.flatten_content(
                content=getattr(msg, "content", "")
            ),
        }
        tool_calls = getattr(msg, "tool_calls", None)
        if tool_calls:
            payload["tool_calls"] = tool_calls
        name = getattr(msg, "name", None)
        if name:
            payload["name"] = name
        return payload

    @staticmethod
    def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
        """Normalize a dict-shaped message (e.g. from serialized state).

        Accepts messages represented as plain mappings rather than
        BaseMessage instances and maps their "role"/"type" fields onto the
        same normalized structure as :meth:`message_to_dict`.

        Args:
            obj: Mapping with fields like "role", "type", "content", "text".

        Returns:
            A dict with "role", "content" and optionally "tool_calls"/"name".
            Unrecognized roles default to "assistant".
        """
        role: Optional[str] = obj.get("role")
        if not role:
            # Map LangChain "type" values onto frontend roles
            typ = obj.get("type")
            if typ in ("human", "user"):
                role = "user"
            elif typ in ("ai", "assistant"):
                role = "assistant"
            elif typ in ("system",):
                role = "system"
            elif typ in ("tool", "function"):
                role = "tool"

        content = obj.get("content")
        if content is None and "text" in obj:
            content = obj["text"]

        out: Dict[str, Any] = {
            "role": role or "assistant",
            "content": ChatStreamingHelper.flatten_content(content=content),
        }
        if "tool_calls" in obj:
            out["tool_calls"] = obj["tool_calls"]
        if obj.get("name"):
            out["name"] = obj["name"]
        return out

    @staticmethod
    def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
        """Extract the message list from a LangGraph output object.

        Args:
            output_obj: The LangGraph result: a dict with a "messages" key,
                an object with a ``messages`` attribute, or anything else.

        Returns:
            The extracted message list, or ``[]`` when ``output_obj`` is
            None or carries no list-valued messages.
        """
        if output_obj is None:
            return []

        # Dict-shaped output (the common case for graph state)
        if isinstance(output_obj, dict):
            msgs = output_obj.get("messages")
            return msgs if isinstance(msgs, list) else []

        # Object with a messages attribute
        msgs = getattr(output_obj, "messages", None)
        return msgs if isinstance(msgs, list) else []
+ """ + logger.info(f"User {user.id} posted message to thread {thread_id}") + + # Get user permissions + tool_ids = permissions.get_chatbot_tools(user_id=user.id) + if not tool_ids: + raise ValueError("User does not have permission to use any chatbot tools") + + model_name = permissions.get_chatbot_model(user_id=user.id) + system_prompt = permissions.get_system_prompt(user_id=user.id) + + # Get tools from registry + registry = get_registry() + tools = registry.get_tool_instances(tool_ids=tool_ids) + + # Get model and checkpointer + model = get_langchain_model(model_name=model_name) + checkpointer = get_checkpointer() + + # Get context window size from config + context_window_size = int( + APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) + ) + + # Create chatbot instance + chatbot = await Chatbot.create( + model=model, + memory=checkpointer, + system_prompt=system_prompt, + tools=tools, + context_window_size=context_window_size, + ) + + # Send message to chatbot + response = await chatbot.chat(message=message, chat_id=thread_id) + + # Parse the response to the correct format + messages = [] + for msg in response: + # Determine the role of the message + if isinstance(msg, HumanMessage): + role = "user" + elif isinstance(msg, AIMessage): + role = "assistant" + else: + continue # Skip any other message types + + # Skip messages that are structured content, such as tool calls + if not isinstance(msg.content, str): + continue + + # Append message to chat history + item = MessageItem( + role=role, + content=msg.content.strip(), + timestamp=0.0, # TODO: Add proper timestamp handling + ) + messages.append(item) + + return ChatMessageResponse(thread_id=thread_id, messages=messages) + + +async def post_message_stream( + *, + thread_id: str, + message: str, + user: User, +) -> AsyncIterator[str]: + """Post a chat message to the chatbot and stream progress updates (SSE). + + Args: + thread_id: The unique identifier for the chat thread. 
+ message: The content of the chat message. + user: The current user. + + Yields: + Server-Sent Events formatted strings containing status updates and final response. + """ + logger.info(f"User {user.id} streaming message to thread {thread_id}") + + try: + # Get user permissions + tool_ids = permissions.get_chatbot_tools(user_id=user.id) + if not tool_ids: + yield ( + "data: " + + json.dumps( + { + "type": "error", + "message": "User does not have permission to use any chatbot tools", + } + ) + + "\n\n" + ) + return + + model_name = permissions.get_chatbot_model(user_id=user.id) + system_prompt = permissions.get_system_prompt(user_id=user.id) + + # Get tools from registry + registry = get_registry() + tools = registry.get_tool_instances(tool_ids=tool_ids) + + # Get model and checkpointer + model = get_langchain_model(model_name=model_name) + checkpointer = get_checkpointer() + + # Get context window size from config + context_window_size = int( + APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000) + ) + + # Create chatbot instance + chatbot = await Chatbot.create( + model=model, + memory=checkpointer, + system_prompt=system_prompt, + tools=tools, + context_window_size=context_window_size, + ) + + # Stream events from chatbot + async for event in chatbot.stream_events(message=message, chat_id=thread_id): + etype = event.get("type") + + # Forward status updates + if etype == "status": + yield f"data: {json.dumps({'type': 'status', 'label': event.get('label')})}\n\n" + continue + + # Forward final response + if etype == "final": + response_from_event = event.get("response") or {} + + # Use the chat history from the final event (already normalized by stream_events) + chat_history_payload = response_from_event.get("chat_history", []) + if isinstance(chat_history_payload, list): + # Convert to MessageItem format + items: List[MessageItem] = [] + for it in chat_history_payload: + role = it.get("role") + content = it.get("content", "") + if role in ("user", 
"assistant") and content: + items.append( + MessageItem( + role=role, + content=content, + timestamp=0.0, # TODO: Add proper timestamp handling + ) + ) + + response = ChatMessageResponse(thread_id=thread_id, messages=items) + # Yield the final response and exit + yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n" + return + else: + # Unexpected payload format - log warning and return empty history + logger.warning( + f"Unexpected chat_history format in final event: {type(chat_history_payload)}" + ) + response = ChatMessageResponse(thread_id=thread_id, messages=[]) + yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n" + return + + # Forward error events + if etype == "error": + yield f"data: {json.dumps(event)}\n\n" + return + + except Exception as e: + logger.error(f"Error in streaming chat: {str(e)}", exc_info=True) + yield ( + "data: " + + json.dumps( + { + "type": "error", + "message": "An error occurred while processing your request.", + } + ) + + "\n\n" + ) diff --git a/modules/features/chatBot/utils/checkpointer.py b/modules/features/chatBot/utils/checkpointer.py new file mode 100644 index 00000000..0aebbda6 --- /dev/null +++ b/modules/features/chatBot/utils/checkpointer.py @@ -0,0 +1,95 @@ +"""PostgreSQL checkpointer utilities for LangGraph memory.""" + +import logging +from typing import Optional + +from langgraph.checkpoint.postgres import PostgresSaver +from psycopg_pool import AsyncConnectionPool +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + +# Global checkpointer instance +_checkpointer_instance: Optional[PostgresSaver] = None +_connection_pool: Optional[AsyncConnectionPool] = None + + +async def initialize_checkpointer() -> None: + """Initialize the PostgreSQL checkpointer for LangGraph. + + This should be called during application startup. + Creates a connection pool and PostgresSaver instance. 
+ """ + global _checkpointer_instance, _connection_pool + + if _checkpointer_instance is not None: + logger.info("Checkpointer already initialized") + return + + try: + # Get database configuration from environment + host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost") + database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat") + user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev") + password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET") + port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432") + + # Build connection string + connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}" + + # Create async connection pool + _connection_pool = AsyncConnectionPool( + conninfo=connection_string, + min_size=2, + max_size=10, + ) + + # Initialize the connection pool + await _connection_pool.open() + + # Create PostgresSaver with the pool + _checkpointer_instance = PostgresSaver(_connection_pool) + + # Setup the checkpointer (creates tables if needed) + await _checkpointer_instance.setup() + + logger.info("PostgreSQL checkpointer initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}") + raise + + +async def close_checkpointer() -> None: + """Close the checkpointer and connection pool. + + This should be called during application shutdown. + """ + global _checkpointer_instance, _connection_pool + + if _connection_pool is not None: + try: + await _connection_pool.close() + logger.info("PostgreSQL checkpointer connection pool closed") + except Exception as e: + logger.error(f"Error closing checkpointer connection pool: {str(e)}") + + _checkpointer_instance = None + _connection_pool = None + + +def get_checkpointer() -> PostgresSaver: + """Get the global PostgreSQL checkpointer instance. 
+ + Returns: + The initialized PostgresSaver instance + + Raises: + RuntimeError: If checkpointer is not initialized + """ + if _checkpointer_instance is None: + raise RuntimeError( + "PostgreSQL checkpointer not initialized. " + "Call initialize_checkpointer() during application startup." + ) + return _checkpointer_instance diff --git a/modules/features/chatBot/utils/permissions.py b/modules/features/chatBot/utils/permissions.py index f8e57a30..d2fb4d65 100644 --- a/modules/features/chatBot/utils/permissions.py +++ b/modules/features/chatBot/utils/permissions.py @@ -12,36 +12,15 @@ from modules.features.chatBot.utils.toolRegistry import get_registry # TODO: Replace these mock implementations with actual database queries -def get_allowed_tools(*, user_id: str) -> list[str]: - """Get list of tool IDs that a user is allowed to use. - - This is a mock implementation that returns all available tools - regardless of user_id. In production, this will query the database - for user-specific permissions. - - Args: - user_id: The unique identifier of the user - - Returns: - List of tool IDs (e.g., ["shared.tavily_search", "customer.query_althaus_database"]) - """ +def get_chatbot_tools(*, user_id: str) -> list[str]: + """Get list of tool IDs that the chatbot can use for a given user.""" registry = get_registry() return registry.list_tool_ids() -def get_allowed_models(*, user_id: str) -> list[str]: - """Get list of AI models that a user is allowed to use. - - This is a mock implementation that returns a fixed list of models. - In production, this will query the database for user-specific model permissions. 
# --- modules/features/chatBot/utils/permissions.py (continued) ---


def get_chatbot_model(*, user_id: str) -> str:
    """Return the chat model identifier the given user may use.

    Mock implementation: every user currently gets the same model.
    TODO: replace with a real per-user permission query.
    NOTE(review): the previous mock advertised "claude-4-5" (hyphenated);
    confirm "claude_4_5" matches what get_langchain_model() expects.
    """
    return "claude_4_5"


# --- modules/routes/routeChatbot.py (module head) ---

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.requests import Request
from fastapi.responses import StreamingResponse
from typing import Any, Dict, List, Optional
from datetime import datetime
import logging
import uuid

from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelChatbot import (
    ChatMessageRequest,
    MessageItem,
    ChatMessageResponse,
    ThreadSummary,
    ThreadListResponse,
    ThreadDetail,
    DeleteResponse,
)
from modules.security.auth import getCurrentUser, limiter
from modules.features.chatBot import service as chat_service

logger = logging.getLogger(__name__)
posting a chat message""" - - thread_id: str = Field(..., description="Thread ID") - messages: List[MessageItem] = Field(..., description="All messages in thread") - - -class ThreadSummary(BaseModel): - """Summary of a chat thread for list view""" - - thread_id: str = Field(..., description="Thread ID") - created_at: float = Field(..., description="Thread creation timestamp") - last_message: str = Field(..., description="Last message content") - message_count: int = Field(..., description="Total number of messages") - - -class ThreadListResponse(BaseModel): - """Response model for listing all threads""" - - threads: List[ThreadSummary] = Field(..., description="List of thread summaries") - - -class ThreadDetail(BaseModel): - """Detailed view of a single thread""" - - thread_id: str = Field(..., description="Thread ID") - created_at: float = Field(..., description="Thread creation timestamp") - messages: List[MessageItem] = Field( - ..., description="All messages in chronological order" - ) - - -class DeleteResponse(BaseModel): - """Response model for delete operations""" - - message: str = Field(..., description="Confirmation message") - thread_id: str = Field(..., description="Deleted thread ID") - # --- Actual endpoints for chatbot --- +@router.post("/message/stream") +@limiter.limit("30/minute") +async def post_chat_message_stream( + *, + request: Request, + message_request: ChatMessageRequest, + currentUser: User = Depends(getCurrentUser), +) -> StreamingResponse: + """ + Post a message to a chat thread with streaming progress updates. + Creates a new thread if thread_id is not provided. + + Returns Server-Sent Events (SSE) stream with status updates and final response. 
+ """ + try: + # Generate or use existing thread_id + thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" + + logger.info( + f"User {currentUser.id} posted streaming message to thread {thread_id}" + ) + + return StreamingResponse( + chat_service.post_message_stream( + thread_id=thread_id, + message=message_request.message, + user=currentUser, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + + except Exception as e: + logger.error(f"Error posting chat message: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to post message: {str(e)}", + ) + + @router.post("/message", response_model=ChatMessageResponse) @limiter.limit("30/minute") async def post_chat_message( @@ -88,35 +83,31 @@ async def post_chat_message( currentUser: User = Depends(getCurrentUser), ) -> ChatMessageResponse: """ - Post a message to a chat thread and get assistant response. + Post a message to a chat thread and get assistant response (non-streaming). Creates a new thread if thread_id is not provided. - This endpoint will later be connected to LangGraph's checkpointer. + For streaming updates, use the /message/stream endpoint instead. """ try: # Generate or use existing thread_id thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}" - # Get current timestamp - current_time = datetime.now().timestamp() - - # Create dummy message history - # In production, this will fetch from LangGraph's checkpointer - messages = [ - MessageItem( - role="user", content=message_request.message, timestamp=current_time - ), - MessageItem( - role="assistant", - content=f"Echo: {message_request.message} (This is a dummy response. 
LangGraph integration pending.)", - timestamp=current_time + 0.5, - ), - ] - logger.info(f"User {currentUser.id} posted message to thread {thread_id}") - return ChatMessageResponse(thread_id=thread_id, messages=messages) + response = await chat_service.post_message( + thread_id=thread_id, + message=message_request.message, + user=currentUser, + ) + return response + + except ValueError as e: + logger.error(f"Permission error: {str(e)}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) except Exception as e: logger.error(f"Error posting chat message: {str(e)}") raise HTTPException( diff --git a/requirements.txt b/requirements.txt index 28c8bb99..2378ca97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ fastapi==0.104.1 websockets==12.0 uvicorn==0.23.2 python-multipart==0.0.6 -httpx==0.25.0 +httpx>=0.25.2 pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility email-validator==2.0.0 # Required by Pydantic for email validation slowapi==0.1.8 # For rate limiting @@ -113,3 +113,7 @@ psycopg2-binary==2.9.9 langchain==0.3.27 langgraph==0.6.8 langchain-core==0.3.77 +langchain-anthropic==0.3.1 # For Claude models +psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer) +psycopg-pool==3.2.1 # Connection pooling for PostgreSQL +langgraph-checkpoint-postgres==2.0.24