feat: chatbot w/ streaming basics
parent 8707203ac2
commit 4bfeded9d0
12 changed files with 1315 additions and 159 deletions
app.py (17 changed lines)
@@ -240,9 +240,26 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Application is starting up")

    # Initialize LangGraph checkpointer
    from modules.features.chatBot.utils.checkpointer import (
        initialize_checkpointer,
        close_checkpointer,
    )

    try:
        await initialize_checkpointer()
        logger.info("LangGraph checkpointer initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}")
        # Continue startup even if checkpointer fails to initialize

    eventManager.start()
    yield

    # Cleanup
    eventManager.stop()
    await close_checkpointer()
    logger.info("Application has been shut down")

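The hunk above only changes the body of the lifespan handler; where the handler is registered on the application lies outside this hunk. In FastAPI that registration is normally a constructor argument, roughly as follows (assumed wiring, not shown in this commit):

    # assumed wiring, not part of this diff: the handler only takes effect once
    # it is passed to the FastAPI constructor
    app = FastAPI(lifespan=lifespan)
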
@@ -8,10 +8,18 @@ import uuid
|
|||
|
||||
|
||||
class ChatStat(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
|
||||
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
|
||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: Optional[str] = Field(
|
||||
None, description="Foreign key to workflow (for workflow stats)"
|
||||
)
|
||||
messageId: Optional[str] = Field(
|
||||
None, description="Foreign key to message (for message stats)"
|
||||
)
|
||||
processingTime: Optional[float] = Field(
|
||||
None, description="Processing time in seconds"
|
||||
)
|
||||
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
|
||||
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
|
||||
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
|
||||
|
|
@@ -37,14 +45,23 @@ register_model_labels(
|
|||
|
||||
|
||||
class ChatLog(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
message: str = Field(description="Log message")
|
||||
type: str = Field(description="Log type (info, warning, error, etc.)")
|
||||
timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)")
|
||||
timestamp: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the log entry was created (UTC timestamp in seconds)",
|
||||
)
|
||||
status: Optional[str] = Field(None, description="Status of the log entry")
|
||||
progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)")
|
||||
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
|
||||
progress: Optional[float] = Field(
|
||||
None, description="Progress indicator (0.0 to 1.0)"
|
||||
)
|
||||
performance: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Performance metrics"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@@ -64,7 +81,9 @@ register_model_labels(
|
|||
|
||||
|
||||
class ChatDocument(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
messageId: str = Field(description="Foreign key to message")
|
||||
fileId: str = Field(description="Foreign key to file")
|
||||
fileName: str = Field(description="Name of the file")
|
||||
|
|
@@ -73,7 +92,9 @@ class ChatDocument(BaseModel, ModelMixin):
|
|||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||
actionId: Optional[str] = Field(None, description="ID of the action that created this document")
|
||||
actionId: Optional[str] = Field(
|
||||
None, description="ID of the action that created this document"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@@ -96,13 +117,19 @@ register_model_labels(
|
|||
|
||||
class ContentMetadata(BaseModel, ModelMixin):
|
||||
size: int = Field(description="Content size in bytes")
|
||||
pages: Optional[int] = Field(None, description="Number of pages for multi-page content")
|
||||
pages: Optional[int] = Field(
|
||||
None, description="Number of pages for multi-page content"
|
||||
)
|
||||
error: Optional[str] = Field(None, description="Processing error if any")
|
||||
width: Optional[int] = Field(None, description="Width in pixels for images/videos")
|
||||
height: Optional[int] = Field(None, description="Height in pixels for images/videos")
|
||||
height: Optional[int] = Field(
|
||||
None, description="Height in pixels for images/videos"
|
||||
)
|
||||
colorMode: Optional[str] = Field(None, description="Color mode")
|
||||
fps: Optional[float] = Field(None, description="Frames per second for videos")
|
||||
durationSec: Optional[float] = Field(None, description="Duration in seconds for media")
|
||||
durationSec: Optional[float] = Field(
|
||||
None, description="Duration in seconds for media"
|
||||
)
|
||||
mimeType: str = Field(description="MIME type of the content")
|
||||
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
|
||||
|
||||
|
|
@@ -144,7 +171,9 @@ register_model_labels(
|
|||
|
||||
class ExtractedContent(BaseModel, ModelMixin):
|
||||
id: str = Field(description="Reference to source ChatDocument")
|
||||
contents: List[ContentItem] = Field(default_factory=list, description="List of content items")
|
||||
contents: List[ContentItem] = Field(
|
||||
default_factory=list, description="List of content items"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@@ -156,27 +185,53 @@ register_model_labels(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||
)
|
||||
workflowId: str = Field(description="Foreign key to workflow")
|
||||
parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading")
|
||||
documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents")
|
||||
documentsLabel: Optional[str] = Field(None, description="Label for the set of documents")
|
||||
parentMessageId: Optional[str] = Field(
|
||||
None, description="Parent message ID for threading"
|
||||
)
|
||||
documents: List[ChatDocument] = Field(
|
||||
default_factory=list, description="Associated documents"
|
||||
)
|
||||
documentsLabel: Optional[str] = Field(
|
||||
None, description="Label for the set of documents"
|
||||
)
|
||||
message: Optional[str] = Field(None, description="Message content")
|
||||
role: str = Field(description="Role of the message sender")
|
||||
status: str = Field(description="Status of the message (first, step, last)")
|
||||
sequenceNr: int = Field(description="Sequence number of the message (set automatically)")
|
||||
publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)")
|
||||
sequenceNr: int = Field(
|
||||
description="Sequence number of the message (set automatically)"
|
||||
)
|
||||
publishedAt: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the message was published (UTC timestamp in seconds)",
|
||||
)
|
||||
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
|
||||
success: Optional[bool] = Field(None, description="Whether the message processing was successful")
|
||||
actionId: Optional[str] = Field(None, description="ID of the action that produced this message")
|
||||
actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message")
|
||||
actionName: Optional[str] = Field(None, description="Name of the action that produced this message")
|
||||
success: Optional[bool] = Field(
|
||||
None, description="Whether the message processing was successful"
|
||||
)
|
||||
actionId: Optional[str] = Field(
|
||||
None, description="ID of the action that produced this message"
|
||||
)
|
||||
actionMethod: Optional[str] = Field(
|
||||
None, description="Method of the action that produced this message"
|
||||
)
|
||||
actionName: Optional[str] = Field(
|
||||
None, description="Name of the action that produced this message"
|
||||
)
|
||||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||
taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, success, fail, retry")
|
||||
actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail")
|
||||
taskProgress: Optional[str] = Field(
|
||||
None, description="Task progress status: pending, running, success, fail, retry"
|
||||
)
|
||||
actionProgress: Optional[str] = Field(
|
||||
None, description="Action progress status: pending, running, success, fail"
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@@ -208,31 +263,135 @@ register_model_labels(
|
|||
|
||||
|
||||
class ChatWorkflow(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
mandateId: str = Field(
|
||||
description="ID of the mandate this workflow belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
status: str = Field(
|
||||
description="Current status of the workflow",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
frontend_options=[
|
||||
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
|
||||
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
|
||||
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
|
||||
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
|
||||
])
|
||||
name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True)
|
||||
currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
||||
lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
||||
{"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}},
|
||||
],
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True,
|
||||
)
|
||||
currentRound: int = Field(
|
||||
description="Current round number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
currentTask: int = Field(
|
||||
default=0,
|
||||
description="Current task number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
currentAction: int = Field(
|
||||
default=0,
|
||||
description="Current action number",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
totalTasks: int = Field(
|
||||
default=0,
|
||||
description="Total number of tasks in the workflow",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
totalActions: int = Field(
|
||||
default=0,
|
||||
description="Total number of actions in the workflow",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
lastActivity: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Timestamp of last activity (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
startedAt: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="When the workflow started (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
logs: List[ChatLog] = Field(
|
||||
default_factory=list,
|
||||
description="Workflow logs",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
messages: List[ChatMessage] = Field(
|
||||
default_factory=list,
|
||||
description="Messages in the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
stats: Optional[ChatStat] = Field(
|
||||
None,
|
||||
description="Workflow statistics",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
tasks: list = Field(
|
||||
default_factory=list,
|
||||
description="List of tasks in the workflow",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
workflowMode: str = Field(
|
||||
default="Actionplan",
|
||||
description="Workflow mode selector",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
frontend_options=[
|
||||
{
|
||||
"value": "Actionplan",
|
||||
"label": {"en": "Action Plan", "fr": "Plan d'actions"},
|
||||
},
|
||||
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
|
||||
])
|
||||
maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False)
|
||||
],
|
||||
)
|
||||
maxSteps: int = Field(
|
||||
default=5,
|
||||
description="Maximum number of iterations in react mode",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False,
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@@ -278,7 +437,10 @@ register_model_labels(
|
|||
"completed_tasks": {"en": "Completed Tasks", "fr": "Tâches terminées"},
|
||||
"total_tasks": {"en": "Total Tasks", "fr": "Total des tâches"},
|
||||
"execution_time": {"en": "Execution Time", "fr": "Temps d'exécution"},
|
||||
"final_results_count": {"en": "Final Results Count", "fr": "Nombre de résultats finaux"},
|
||||
"final_results_count": {
|
||||
"en": "Final Results Count",
|
||||
"fr": "Nombre de résultats finaux",
|
||||
},
|
||||
"error": {"en": "Error", "fr": "Erreur"},
|
||||
"phase": {"en": "Phase", "fr": "Phase"},
|
||||
},
|
||||
|
|
@@ -300,5 +462,3 @@ register_model_labels(
|
|||
"userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
modules/datamodels/datamodelChatbot.py (new file, 130 lines)
@@ -0,0 +1,130 @@
|
|||
"""Chatbot API models for request/response handling."""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
|
||||
|
||||
# Chatbot API Models
|
||||
class MessageItem(BaseModel, ModelMixin):
|
||||
"""Individual message in a thread"""
|
||||
|
||||
role: str = Field(..., description="Message role (user or assistant)")
|
||||
content: str = Field(..., description="Message content")
|
||||
timestamp: float = Field(..., description="Message timestamp (Unix timestamp)")
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel, ModelMixin):
|
||||
"""Request model for posting a chat message"""
|
||||
|
||||
thread_id: Optional[str] = Field(
|
||||
None, description="Thread ID (creates new thread if not provided)"
|
||||
)
|
||||
message: str = Field(..., description="User message content")
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel, ModelMixin):
|
||||
"""Response model for posting a chat message"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
||||
|
||||
|
||||
class ThreadSummary(BaseModel, ModelMixin):
|
||||
"""Summary of a chat thread for list view"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
created_at: float = Field(..., description="Thread creation timestamp")
|
||||
last_message: str = Field(..., description="Last message content")
|
||||
message_count: int = Field(..., description="Total number of messages")
|
||||
|
||||
|
||||
class ThreadListResponse(BaseModel, ModelMixin):
|
||||
"""Response model for listing all threads"""
|
||||
|
||||
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
||||
|
||||
|
||||
class ThreadDetail(BaseModel, ModelMixin):
|
||||
"""Detailed view of a single thread"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
created_at: float = Field(..., description="Thread creation timestamp")
|
||||
messages: List[MessageItem] = Field(
|
||||
..., description="All messages in chronological order"
|
||||
)
|
||||
|
||||
|
||||
class DeleteResponse(BaseModel, ModelMixin):
|
||||
"""Response model for delete operations"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
thread_id: str = Field(..., description="Deleted thread ID")
|
||||
|
||||
|
||||
# Register model labels for internationalization
|
||||
register_model_labels(
|
||||
"MessageItem",
|
||||
{"en": "Message Item", "fr": "Élément de message"},
|
||||
{
|
||||
"role": {"en": "Role", "fr": "Rôle"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
"timestamp": {"en": "Timestamp", "fr": "Horodatage"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageRequest",
|
||||
{"en": "Chat Message Request", "fr": "Demande de message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ChatMessageResponse",
|
||||
{"en": "Chat Message Response", "fr": "Réponse du message de chat"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadSummary",
|
||||
{"en": "Thread Summary", "fr": "Résumé du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"created_at": {"en": "Created At", "fr": "Créé le"},
|
||||
"last_message": {"en": "Last Message", "fr": "Dernier message"},
|
||||
"message_count": {"en": "Message Count", "fr": "Nombre de messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadListResponse",
|
||||
{"en": "Thread List Response", "fr": "Réponse de liste de fils"},
|
||||
{
|
||||
"threads": {"en": "Threads", "fr": "Fils"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"ThreadDetail",
|
||||
{"en": "Thread Detail", "fr": "Détail du fil"},
|
||||
{
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
"created_at": {"en": "Created At", "fr": "Créé le"},
|
||||
"messages": {"en": "Messages", "fr": "Messages"},
|
||||
},
|
||||
)
|
||||
|
||||
register_model_labels(
|
||||
"DeleteResponse",
|
||||
{"en": "Delete Response", "fr": "Réponse de suppression"},
|
||||
{
|
||||
"message": {"en": "Message", "fr": "Message"},
|
||||
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||
},
|
||||
)
|
||||
|
|
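As a usage illustration (hypothetical, not part of the commit), the request/response models above round-trip with Pydantic v2 like this:

    # hypothetical example; omitting thread_id on the request means the API creates a new thread
    req = ChatMessageRequest(message="Hello")
    resp = ChatMessageResponse(
        thread_id="thread_123",
        messages=[MessageItem(role="user", content="Hello", timestamp=0.0)],
    )
    print(resp.model_dump())  # serializable dict: {"thread_id": "thread_123", "messages": [...]}
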
@@ -0,0 +1,24 @@
"""Tool for sending streaming status updates to users."""

from langchain_core.tools import tool


@tool
def send_streaming_message(message: str) -> str:
    """Send a streaming message to the user to provide updates during processing.

    Use this tool to send short status updates to the user while you are working
    on their request. This helps keep the user informed about what you are doing.

    Args:
        message: A short message describing what you are currently doing.
            Examples: "Searching database for relevant information..."
                "Analyzing search results..."
                "Processing your request..."

    Returns:
        A confirmation that the message was sent.
    """
    # This tool doesn't actually do anything - it's just for the AI to signal
    # what it's doing to the frontend via the tool call mechanism
    return f"Status update sent: {message}"

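For the no-op status tool above to be called at all, it has to be offered to the model alongside the real tools. A minimal sketch of that binding (`llm` stands for any LangChain chat model with tool-calling support; it is an assumption, not taken from this file):

    # sketch; `llm` is an assumed chat model instance, e.g. a ChatAnthropic object
    llm_with_tools = llm.bind_tools([send_streaming_message])

The streaming layer then watches for on_tool_start events carrying this tool's name and forwards the message argument to the client as a status update (see stream_events in chatbot.py below).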
modules/features/chatBot/domain/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Domain logic for chatbot functionality."""
modules/features/chatBot/domain/chatbot.py (new file, 301 lines)
@@ -0,0 +1,301 @@
|
|||
"""Chatbot domain logic with LangGraph integration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, AsyncIterator, Any
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatState(BaseModel):
|
||||
"""Represents the state of a chat session."""
|
||||
|
||||
messages: Annotated[list[BaseMessage], add_messages]
|
||||
|
||||
|
||||
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
|
||||
"""Map permission model names to LangChain ChatAnthropic models.
|
||||
|
||||
Args:
|
||||
model_name: The model name from permissions (e.g., "claude_4_5")
|
||||
|
||||
Returns:
|
||||
Configured ChatAnthropic instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the model name is not supported
|
||||
"""
|
||||
# Model name mapping
|
||||
model_mapping = {
|
||||
"claude_4_5": "claude-4-5-sonnet",
|
||||
# Add more mappings as needed
|
||||
}
|
||||
|
||||
anthropic_model = model_mapping.get(model_name)
|
||||
if not anthropic_model:
|
||||
logger.warning(
|
||||
f"Unknown model name '{model_name}', defaulting to claude-4-5-sonnet"
|
||||
)
|
||||
anthropic_model = "claude-4-5-sonnet"
|
||||
|
||||
return ChatAnthropic(
|
||||
model=anthropic_model,
|
||||
api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
|
||||
temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
|
||||
max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chatbot:
|
||||
"""Represents a chatbot with LangGraph integration."""
|
||||
|
||||
model: Any
|
||||
memory: Any
|
||||
app: Any = None
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
context_window_size: int = 100000
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
*,
|
||||
model: Any,
|
||||
memory: Any,
|
||||
system_prompt: str,
|
||||
tools: list,
|
||||
context_window_size: int = 100000,
|
||||
) -> "Chatbot":
|
||||
"""Factory method to create and configure a Chatbot instance.
|
||||
|
||||
Args:
|
||||
model: The chat model to use.
|
||||
memory: The chat memory checkpointer to use.
|
||||
system_prompt: The system prompt to initialize the chatbot.
|
||||
tools: List of LangChain tools the chatbot can use.
|
||||
context_window_size: Maximum tokens for context window.
|
||||
|
||||
Returns:
|
||||
A configured Chatbot instance.
|
||||
"""
|
||||
instance = cls(
|
||||
model=model,
|
||||
memory=memory,
|
||||
system_prompt=system_prompt,
|
||||
context_window_size=context_window_size,
|
||||
)
|
||||
instance.app = instance._build_app(memory=memory, tools=tools)
|
||||
return instance
|
||||
|
||||
def _build_app(
|
||||
self, *, memory: Any, tools: list
|
||||
) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
|
||||
"""Builds the chatbot application workflow using LangGraph.
|
||||
|
||||
Args:
|
||||
memory: The chat memory checkpointer to use.
|
||||
tools: The list of tools the chatbot can use.
|
||||
|
||||
Returns:
|
||||
A compiled state graph representing the chatbot application.
|
||||
"""
|
||||
llm_with_tools = self.model.bind_tools(tools=tools)
|
||||
|
||||
def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Selects a window of messages that fit within the context window size.
|
||||
|
||||
Args:
|
||||
msgs: The list of messages to select from.
|
||||
|
||||
Returns:
|
||||
A list of messages that fit within the context window size.
|
||||
"""
|
||||
|
||||
def approx_counter(items: list[BaseMessage]) -> int:
|
||||
"""Approximate token counter for messages.
|
||||
|
||||
Args:
|
||||
items: List of messages to count tokens for.
|
||||
|
||||
Returns:
|
||||
Approximate number of tokens in the messages.
|
||||
"""
|
||||
return sum(len(getattr(m, "content", "") or "") for m in items)
|
||||
|
||||
return trim_messages(
|
||||
msgs,
|
||||
strategy="last",
|
||||
token_counter=approx_counter,
|
||||
max_tokens=self.context_window_size,
|
||||
start_on="human",
|
||||
end_on=("human", "tool"),
|
||||
include_system=True,
|
||||
)
|
||||
|
||||
def agent_node(state: ChatState) -> dict:
|
||||
"""Agent node for the chatbot workflow.
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The updated chat state after processing.
|
||||
"""
|
||||
# Select the message window to fit in context (trim if needed)
|
||||
window = select_window(state.messages)
|
||||
|
||||
# Ensure the system prompt is present at the start
|
||||
if not window or not isinstance(window[0], SystemMessage):
|
||||
window = [SystemMessage(content=self.system_prompt)] + window
|
||||
|
||||
# Call the LLM with tools
|
||||
response = llm_with_tools.invoke(window)
|
||||
|
||||
# Return the new state
|
||||
return {"messages": [response]}
|
||||
|
||||
def should_continue(state: ChatState) -> str:
|
||||
"""Determines whether to continue the workflow or end it.
|
||||
|
||||
This conditional edge is called after the agent node to decide
|
||||
whether to continue to the tools node (if the last message contains
|
||||
tool calls) or to end the workflow (if no tool calls are present).
|
||||
|
||||
Args:
|
||||
state: The current chat state.
|
||||
|
||||
Returns:
|
||||
The next node to transition to ("tools" or END).
|
||||
"""
|
||||
# Get the last message
|
||||
last_message = state.messages[-1]
|
||||
|
||||
# Check if the last message contains tool calls
|
||||
# If so, continue to the tools node; otherwise, end the workflow
|
||||
return "tools" if getattr(last_message, "tool_calls", None) else END
|
||||
|
||||
# Compose the workflow
|
||||
workflow = StateGraph(ChatState)
|
||||
workflow.add_node("agent", agent_node)
|
||||
workflow.add_node("tools", ToolNode(tools=tools))
|
||||
workflow.add_edge(START, "agent")
|
||||
workflow.add_conditional_edges("agent", should_continue)
|
||||
workflow.add_edge("tools", "agent")
|
||||
return workflow.compile(checkpointer=memory)
|
||||
|
||||
async def chat(
|
||||
self, *, message: str, chat_id: str = "default"
|
||||
) -> list[BaseMessage]:
|
||||
"""Processes a chat message and returns the chat history.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: The chat thread ID.
|
||||
|
||||
Returns:
|
||||
The list of messages in the chat history.
|
||||
"""
|
||||
# Set the right thread ID for memory
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
# Single-turn chat (non-streaming)
|
||||
result = await self.app.ainvoke(
|
||||
{"messages": [HumanMessage(content=message)]}, config=config
|
||||
)
|
||||
|
||||
# Extract and return the messages from the result
|
||||
return result["messages"]
|
||||
|
||||
async def stream_events(
|
||||
self, *, message: str, chat_id: str = "default"
|
||||
) -> AsyncIterator[dict]:
|
||||
"""Stream UI-focused events using astream_events v2.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
chat_id: Logical thread identifier; forwarded in the runnable config so
|
||||
memory and tools are scoped per thread.
|
||||
|
||||
Yields:
|
||||
dict: One of:
|
||||
- ``{"type": "status", "label": str}`` for short progress updates.
|
||||
- ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
|
||||
where ``chat_history`` only includes ``user``/``assistant`` roles.
|
||||
- ``{"type": "error", "message": str}`` if an exception occurs.
|
||||
"""
|
||||
# Thread-aware config for LangGraph/LangChain
|
||||
config = {"configurable": {"thread_id": chat_id}}
|
||||
|
||||
def _is_root(ev: dict) -> bool:
|
||||
"""Return True if the event is from the root run (v2: empty parent_ids)."""
|
||||
return not ev.get("parent_ids")
|
||||
|
||||
try:
|
||||
async for event in self.app.astream_events(
|
||||
{"messages": [HumanMessage(content=message)]},
|
||||
config=config,
|
||||
version="v2",
|
||||
):
|
||||
etype = event.get("event")
|
||||
ename = event.get("name") or ""
|
||||
edata = event.get("data") or {}
|
||||
|
||||
# Stream human-readable progress via the special send_streaming_message tool
|
||||
if etype == "on_tool_start" and ename == "send_streaming_message":
|
||||
tool_in = edata.get("input") or {}
|
||||
msg = tool_in.get("message")
|
||||
if isinstance(msg, str) and msg.strip():
|
||||
yield {"type": "status", "label": msg.strip()}
|
||||
continue
|
||||
|
||||
# Emit the final payload when the root run finishes
|
||||
if etype == "on_chain_end" and _is_root(event):
|
||||
output_obj = edata.get("output")
|
||||
|
||||
# Extract message list from the graph's final output
|
||||
final_msgs = ChatStreamingHelper.extract_messages_from_output(
|
||||
output_obj=output_obj
|
||||
)
|
||||
|
||||
# Normalize for the frontend (only user/assistant with text content)
|
||||
chat_history_payload: list[dict] = []
|
||||
for m in final_msgs:
|
||||
if isinstance(m, BaseMessage):
|
||||
d = ChatStreamingHelper.message_to_dict(msg=m)
|
||||
elif isinstance(m, dict):
|
||||
d = ChatStreamingHelper.dict_message_to_dict(obj=m)
|
||||
else:
|
||||
continue
|
||||
if d.get("role") in ("user", "assistant") and d.get("content"):
|
||||
chat_history_payload.append(d)
|
||||
|
||||
yield {
|
||||
"type": "final",
|
||||
"response": {
|
||||
"thread": chat_id,
|
||||
"chat_history": chat_history_payload,
|
||||
},
|
||||
}
|
||||
return
|
||||
|
||||
except Exception as exc:
|
||||
# Emit a single error envelope and end the stream
|
||||
logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
|
||||
yield {"type": "error", "message": f"Error processing request: {exc}"}
|
||||
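A quick way to exercise the class above without a Postgres checkpointer is to swap in LangGraph's in-memory saver. A hedged sketch (`llm` is any LangChain chat model with tool-calling support, and the import path for send_streaming_message is not shown in this commit):

    # illustrative only; assumes an `llm` instance and the status tool from earlier in this commit
    import asyncio
    from langgraph.checkpoint.memory import MemorySaver
    from modules.features.chatBot.domain.chatbot import Chatbot

    async def demo(llm, send_streaming_message) -> None:
        bot = await Chatbot.create(
            model=llm,
            memory=MemorySaver(),  # in-memory checkpointer instead of Postgres
            system_prompt="You are a helpful assistant.",
            tools=[send_streaming_message],
        )
        async for event in bot.stream_events(message="Hello", chat_id="demo-thread"):
            print(event)  # {"type": "status" | "final" | "error", ...}

    # asyncio.run(demo(llm, send_streaming_message))
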
modules/features/chatBot/domain/streaming_helper.py (new file, 239 lines)
@@ -0,0 +1,239 @@
|
|||
"""Streaming helper utilities for chat message processing and normalization."""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
Role = Literal["user", "assistant", "system", "tool"]
|
||||
|
||||
|
||||
class ChatStreamingHelper:
|
||||
"""Pure helper methods for streaming and message normalization.
|
||||
|
||||
This class provides static utility methods for converting between different
|
||||
message formats, extracting content, and normalizing message structures
|
||||
for streaming chat applications.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def role_from_message(*, msg: BaseMessage) -> Role:
|
||||
"""Extract the role from a BaseMessage instance.
|
||||
|
||||
Args:
|
||||
msg: The BaseMessage instance to extract the role from.
|
||||
|
||||
Returns:
|
||||
The role as a string literal: "user", "assistant", "system", or "tool".
|
||||
Defaults to "assistant" if the message type is not recognized.
|
||||
|
||||
Examples:
|
||||
>>> from langchain_core.messages import HumanMessage
|
||||
>>> msg = HumanMessage(content="Hello")
|
||||
>>> ChatStreamingHelper.role_from_message(msg=msg)
|
||||
'user'
|
||||
"""
|
||||
if isinstance(msg, HumanMessage):
|
||||
return "user"
|
||||
if isinstance(msg, AIMessage):
|
||||
return "assistant"
|
||||
if isinstance(msg, SystemMessage):
|
||||
return "system"
|
||||
if isinstance(msg, ToolMessage):
|
||||
return "tool"
|
||||
return getattr(msg, "role", "assistant")
|
||||
|
||||
@staticmethod
|
||||
def flatten_content(*, content: Any) -> str:
|
||||
"""Convert complex content structures to plain text.
|
||||
|
||||
This method handles various content formats including strings, lists of
|
||||
content parts, and dictionaries with text fields. It's designed to
|
||||
normalize content from different message sources into a consistent
|
||||
plain text format.
|
||||
|
||||
Args:
|
||||
content: The content to flatten. Can be:
|
||||
- str: Returned as-is after stripping whitespace
|
||||
- list: Each item processed and joined with newlines
|
||||
- dict: Text extracted from "text" or "content" fields
|
||||
- None: Returns empty string
|
||||
- Any other type: Converted to string
|
||||
|
||||
Returns:
|
||||
The flattened content as a plain text string with whitespace stripped.
|
||||
|
||||
Examples:
|
||||
>>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
|
||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||
'Hello\nworld'
|
||||
|
||||
>>> content = {"text": "Simple message"}
|
||||
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||
'Simple message'
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts: List[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
if "text" in part and isinstance(part["text"], str):
|
||||
parts.append(part["text"])
|
||||
elif part.get("type") == "text" and isinstance(
|
||||
part.get("text"), str
|
||||
):
|
||||
parts.append(part["text"])
|
||||
elif "content" in part and isinstance(part["content"], str):
|
||||
parts.append(part["content"])
|
||||
else:
|
||||
# Fallback for unknown dictionary structures
|
||||
val = part.get("value")
|
||||
if isinstance(val, str):
|
||||
parts.append(val)
|
||||
else:
|
||||
parts.append(str(part))
|
||||
return "\n".join(p.strip() for p in parts if p is not None)
|
||||
if isinstance(content, dict):
|
||||
if "text" in content and isinstance(content["text"], str):
|
||||
return content["text"].strip()
|
||||
if "content" in content and isinstance(content["content"], str):
|
||||
return content["content"].strip()
|
||||
return str(content).strip()
|
||||
|
||||
@staticmethod
|
||||
def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
|
||||
"""Convert a BaseMessage instance to a dictionary for streaming output.
|
||||
|
||||
This method normalizes BaseMessage instances into a consistent dictionary
|
||||
format suitable for JSON serialization and streaming to clients.
|
||||
|
||||
Args:
|
||||
msg: The BaseMessage instance to convert.
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- "role": The message role (user, assistant, system, tool)
|
||||
- "content": The flattened message content as plain text
|
||||
- "tool_calls": Tool calls if present (optional)
|
||||
- "name": Message name if present (optional)
|
||||
|
||||
Examples:
|
||||
>>> from langchain_core.messages import HumanMessage
|
||||
>>> msg = HumanMessage(content="Hello there")
|
||||
>>> result = ChatStreamingHelper.message_to_dict(msg=msg)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello there'
|
||||
"""
|
||||
payload: Dict[str, Any] = {
|
||||
"role": ChatStreamingHelper.role_from_message(msg=msg),
|
||||
"content": ChatStreamingHelper.flatten_content(
|
||||
content=getattr(msg, "content", "")
|
||||
),
|
||||
}
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if tool_calls:
|
||||
payload["tool_calls"] = tool_calls
|
||||
name = getattr(msg, "name", None)
|
||||
if name:
|
||||
payload["name"] = name
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert a dictionary-shaped message to a normalized dictionary.
|
||||
|
||||
This method handles messages that come from serialized state and are
|
||||
represented as dictionaries rather than BaseMessage instances. It
|
||||
normalizes various dictionary formats into a consistent structure.
|
||||
|
||||
Args:
|
||||
obj: The dictionary-shaped message to convert. Expected to contain
|
||||
fields like "role", "type", "content", "text", etc.
|
||||
|
||||
Returns:
|
||||
A normalized dictionary containing:
|
||||
- "role": The message role (user, assistant, system, tool)
|
||||
- "content": The flattened message content as plain text
|
||||
- "tool_calls": Tool calls if present (optional)
|
||||
- "name": Message name if present (optional)
|
||||
|
||||
Examples:
|
||||
>>> obj = {"type": "human", "content": "Hello"}
|
||||
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
|
||||
>>> result["role"]
|
||||
'user'
|
||||
>>> result["content"]
|
||||
'Hello'
|
||||
"""
|
||||
role: Optional[str] = obj.get("role")
|
||||
if not role:
|
||||
# Handle alternative type field mappings
|
||||
typ = obj.get("type")
|
||||
if typ in ("human", "user"):
|
||||
role = "user"
|
||||
elif typ in ("ai", "assistant"):
|
||||
role = "assistant"
|
||||
elif typ in ("system",):
|
||||
role = "system"
|
||||
elif typ in ("tool", "function"):
|
||||
role = "tool"
|
||||
|
||||
content = obj.get("content")
|
||||
if content is None and "text" in obj:
|
||||
content = obj["text"]
|
||||
|
||||
out: Dict[str, Any] = {
|
||||
"role": role or "assistant",
|
||||
"content": ChatStreamingHelper.flatten_content(content=content),
|
||||
}
|
||||
if "tool_calls" in obj:
|
||||
out["tool_calls"] = obj["tool_calls"]
|
||||
if obj.get("name"):
|
||||
out["name"] = obj["name"]
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
|
||||
"""Extract messages from LangGraph output objects.
|
||||
|
||||
This method handles various output formats from LangGraph execution,
|
||||
extracting the messages list from different possible structures.
|
||||
|
||||
Args:
|
||||
output_obj: The output object from LangGraph execution. Can be:
|
||||
- An object with a "messages" attribute
|
||||
- A dictionary with a "messages" key
|
||||
- Any other object (returns empty list)
|
||||
|
||||
Returns:
|
||||
A list of extracted messages, or an empty list if no messages
|
||||
are found or if the output object is None.
|
||||
|
||||
Examples:
|
||||
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
|
||||
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
|
||||
>>> len(messages)
|
||||
1
|
||||
"""
|
||||
if output_obj is None:
|
||||
return []
|
||||
|
||||
# Try to parse dicts first
|
||||
if isinstance(output_obj, dict):
|
||||
msgs = output_obj.get("messages")
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
|
||||
# Then try to get messages attribute
|
||||
msgs = getattr(output_obj, "messages", None)
|
||||
return msgs if isinstance(msgs, list) else []
|
||||
modules/features/chatBot/service.py (new file, 215 lines)
@@ -0,0 +1,215 @@
|
|||
"""Service layer for chatbot functionality."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
|
||||
from modules.features.chatBot.utils.checkpointer import get_checkpointer
|
||||
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||
from modules.features.chatBot.utils import permissions
|
||||
from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse
|
||||
from modules.datamodels.datamodelUam import User
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def post_message(
|
||||
*,
|
||||
thread_id: str,
|
||||
message: str,
|
||||
user: User,
|
||||
) -> ChatMessageResponse:
|
||||
"""Post a chat message to the chatbot and return the response.
|
||||
|
||||
Args:
|
||||
thread_id: The unique identifier for the chat thread.
|
||||
message: The content of the chat message.
|
||||
user: The current user.
|
||||
|
||||
Returns:
|
||||
The response containing the full chat message history and thread ID.
|
||||
"""
|
||||
logger.info(f"User {user.id} posted message to thread {thread_id}")
|
||||
|
||||
# Get user permissions
|
||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||
if not tool_ids:
|
||||
raise ValueError("User does not have permission to use any chatbot tools")
|
||||
|
||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||
|
||||
# Get tools from registry
|
||||
registry = get_registry()
|
||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||
|
||||
# Get model and checkpointer
|
||||
model = get_langchain_model(model_name=model_name)
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
# Get context window size from config
|
||||
context_window_size = int(
|
||||
APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
|
||||
)
|
||||
|
||||
# Create chatbot instance
|
||||
chatbot = await Chatbot.create(
|
||||
model=model,
|
||||
memory=checkpointer,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
context_window_size=context_window_size,
|
||||
)
|
||||
|
||||
# Send message to chatbot
|
||||
response = await chatbot.chat(message=message, chat_id=thread_id)
|
||||
|
||||
# Parse the response to the correct format
|
||||
messages = []
|
||||
for msg in response:
|
||||
# Determine the role of the message
|
||||
if isinstance(msg, HumanMessage):
|
||||
role = "user"
|
||||
elif isinstance(msg, AIMessage):
|
||||
role = "assistant"
|
||||
else:
|
||||
continue # Skip any other message types
|
||||
|
||||
# Skip messages that are structured content, such as tool calls
|
||||
if not isinstance(msg.content, str):
|
||||
continue
|
||||
|
||||
# Append message to chat history
|
||||
item = MessageItem(
|
||||
role=role,
|
||||
content=msg.content.strip(),
|
||||
timestamp=0.0, # TODO: Add proper timestamp handling
|
||||
)
|
||||
messages.append(item)
|
||||
|
||||
return ChatMessageResponse(thread_id=thread_id, messages=messages)
|
||||
|
||||
|
||||
async def post_message_stream(
|
||||
*,
|
||||
thread_id: str,
|
||||
message: str,
|
||||
user: User,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Post a chat message to the chatbot and stream progress updates (SSE).
|
||||
|
||||
Args:
|
||||
thread_id: The unique identifier for the chat thread.
|
||||
message: The content of the chat message.
|
||||
user: The current user.
|
||||
|
||||
Yields:
|
||||
Server-Sent Events formatted strings containing status updates and final response.
|
||||
"""
|
||||
logger.info(f"User {user.id} streaming message to thread {thread_id}")
|
||||
|
||||
try:
|
||||
# Get user permissions
|
||||
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||
if not tool_ids:
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "User does not have permission to use any chatbot tools",
|
||||
}
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
return
|
||||
|
||||
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||
|
||||
# Get tools from registry
|
||||
registry = get_registry()
|
||||
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||
|
||||
# Get model and checkpointer
|
||||
model = get_langchain_model(model_name=model_name)
|
||||
checkpointer = get_checkpointer()
|
||||
|
||||
# Get context window size from config
|
||||
context_window_size = int(
|
||||
APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
|
||||
)
|
||||
|
||||
# Create chatbot instance
|
||||
chatbot = await Chatbot.create(
|
||||
model=model,
|
||||
memory=checkpointer,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
context_window_size=context_window_size,
|
||||
)
|
||||
|
||||
# Stream events from chatbot
|
||||
async for event in chatbot.stream_events(message=message, chat_id=thread_id):
|
||||
etype = event.get("type")
|
||||
|
||||
# Forward status updates
|
||||
if etype == "status":
|
||||
yield f"data: {json.dumps({'type': 'status', 'label': event.get('label')})}\n\n"
|
||||
continue
|
||||
|
||||
# Forward final response
|
||||
if etype == "final":
|
||||
response_from_event = event.get("response") or {}
|
||||
|
||||
# Use the chat history from the final event (already normalized by stream_events)
|
||||
chat_history_payload = response_from_event.get("chat_history", [])
|
||||
if isinstance(chat_history_payload, list):
|
||||
# Convert to MessageItem format
|
||||
items: List[MessageItem] = []
|
||||
for it in chat_history_payload:
|
||||
role = it.get("role")
|
||||
content = it.get("content", "")
|
||||
if role in ("user", "assistant") and content:
|
||||
items.append(
|
||||
MessageItem(
|
||||
role=role,
|
||||
content=content,
|
||||
timestamp=0.0, # TODO: Add proper timestamp handling
|
||||
)
|
||||
)
|
||||
|
||||
response = ChatMessageResponse(thread_id=thread_id, messages=items)
|
||||
# Yield the final response and exit
|
||||
yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n"
|
||||
return
|
||||
else:
|
||||
# Unexpected payload format - log warning and return empty history
|
||||
logger.warning(
|
||||
f"Unexpected chat_history format in final event: {type(chat_history_payload)}"
|
||||
)
|
||||
response = ChatMessageResponse(thread_id=thread_id, messages=[])
|
||||
yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n"
|
||||
return
|
||||
|
||||
# Forward error events
|
||||
if etype == "error":
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in streaming chat: {str(e)}", exc_info=True)
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "An error occurred while processing your request.",
|
||||
}
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
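For reference, a client consumes the SSE stream produced by post_message_stream line by line. A sketch using httpx (the base URL, route prefix, and auth header are assumptions, not taken from the commit):

    # hypothetical client for the POST /message/stream route added in this commit
    import asyncio
    import json
    import httpx

    async def consume(base_url: str, token: str) -> None:
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                f"{base_url}/message/stream",
                json={"message": "Hello"},
                headers={"Authorization": f"Bearer {token}"},
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        event = json.loads(line[len("data: "):])
                        print(event["type"], event)

    # asyncio.run(consume("http://localhost:8000/<chatbot prefix>", "<token>"))
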
modules/features/chatBot/utils/checkpointer.py (new file, 95 lines)
@@ -0,0 +1,95 @@
|
|||
"""PostgreSQL checkpointer utilities for LangGraph memory."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpointer instance
|
||||
_checkpointer_instance: Optional[PostgresSaver] = None
|
||||
_connection_pool: Optional[AsyncConnectionPool] = None
|
||||
|
||||
|
||||
async def initialize_checkpointer() -> None:
|
||||
"""Initialize the PostgreSQL checkpointer for LangGraph.
|
||||
|
||||
This should be called during application startup.
|
||||
Creates a connection pool and PostgresSaver instance.
|
||||
"""
|
||||
global _checkpointer_instance, _connection_pool
|
||||
|
||||
if _checkpointer_instance is not None:
|
||||
logger.info("Checkpointer already initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
# Get database configuration from environment
|
||||
host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
|
||||
database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
|
||||
user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
|
||||
password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
|
||||
port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
|
||||
|
||||
# Build connection string
|
||||
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}"
|
||||
|
||||
# Create async connection pool
|
||||
_connection_pool = AsyncConnectionPool(
|
||||
conninfo=connection_string,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
)
|
||||
|
||||
# Initialize the connection pool
|
||||
await _connection_pool.open()
|
||||
|
||||
# Create PostgresSaver with the pool
|
||||
_checkpointer_instance = PostgresSaver(_connection_pool)
|
||||
|
||||
# Setup the checkpointer (creates tables if needed)
|
||||
await _checkpointer_instance.setup()
|
||||
|
||||
logger.info("PostgreSQL checkpointer initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def close_checkpointer() -> None:
|
||||
"""Close the checkpointer and connection pool.
|
||||
|
||||
This should be called during application shutdown.
|
||||
"""
|
||||
global _checkpointer_instance, _connection_pool
|
||||
|
||||
if _connection_pool is not None:
|
||||
try:
|
||||
await _connection_pool.close()
|
||||
logger.info("PostgreSQL checkpointer connection pool closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing checkpointer connection pool: {str(e)}")
|
||||
|
||||
_checkpointer_instance = None
|
||||
_connection_pool = None
|
||||
|
||||
|
||||
def get_checkpointer() -> PostgresSaver:
|
||||
"""Get the global PostgreSQL checkpointer instance.
|
||||
|
||||
Returns:
|
||||
The initialized PostgresSaver instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If checkpointer is not initialized
|
||||
"""
|
||||
if _checkpointer_instance is None:
|
||||
raise RuntimeError(
|
||||
"PostgreSQL checkpointer not initialized. "
|
||||
"Call initialize_checkpointer() during application startup."
|
||||
)
|
||||
return _checkpointer_instance
|
||||
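One detail worth noting: PostgresSaver is the synchronous saver in langgraph-checkpoint-postgres, while the pool created in this file is an AsyncConnectionPool and setup() is awaited. If fully async operation is intended, the async saver is the usual pairing; a sketch under that assumption (class and module names should be verified against the pinned 2.0.24 release):

    # sketch assuming the async saver from langgraph-checkpoint-postgres
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
    from psycopg_pool import AsyncConnectionPool

    async def init_async_saver(conninfo: str) -> AsyncPostgresSaver:
        pool = AsyncConnectionPool(conninfo=conninfo, min_size=2, max_size=10, open=False)
        await pool.open()
        saver = AsyncPostgresSaver(pool)
        await saver.setup()  # creates the checkpoint tables if they do not exist
        return saver
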
|
|
@@ -12,36 +12,15 @@ from modules.features.chatBot.utils.toolRegistry import get_registry
|
|||
# TODO: Replace these mock implementations with actual database queries
|
||||
|
||||
|
||||
def get_allowed_tools(*, user_id: str) -> list[str]:
|
||||
"""Get list of tool IDs that a user is allowed to use.
|
||||
|
||||
This is a mock implementation that returns all available tools
|
||||
regardless of user_id. In production, this will query the database
|
||||
for user-specific permissions.
|
||||
|
||||
Args:
|
||||
user_id: The unique identifier of the user
|
||||
|
||||
Returns:
|
||||
List of tool IDs (e.g., ["shared.tavily_search", "customer.query_althaus_database"])
|
||||
"""
|
||||
def get_chatbot_tools(*, user_id: str) -> list[str]:
|
||||
"""Get list of tool IDs that the chatbot can use for a given user."""
|
||||
registry = get_registry()
|
||||
return registry.list_tool_ids()
|
||||
|
||||
|
||||
def get_allowed_models(*, user_id: str) -> list[str]:
|
||||
"""Get list of AI models that a user is allowed to use.
|
||||
|
||||
This is a mock implementation that returns a fixed list of models.
|
||||
In production, this will query the database for user-specific model permissions.
|
||||
|
||||
Args:
|
||||
user_id: The unique identifier of the user
|
||||
|
||||
Returns:
|
||||
List of model identifiers (e.g., ["gpt-5", "claude-4-5"])
|
||||
"""
|
||||
return ["gpt-5", "claude-4-5"]
|
||||
def get_chatbot_model(*, user_id: str) -> str:
|
||||
"""Gets the chatbot model(s) a user is allowed to use."""
|
||||
return "claude_4_5"
|
||||
|
||||
|
||||
def get_system_prompt(*, user_id: str) -> str:
|
||||
|
|
|
|||
|
|
@@ -1,13 +1,23 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelChatbot import (
|
||||
ChatMessageRequest,
|
||||
MessageItem,
|
||||
ChatMessageResponse,
|
||||
ThreadSummary,
|
||||
ThreadListResponse,
|
||||
ThreadDetail,
|
||||
DeleteResponse,
|
||||
)
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.features.chatBot import service as chat_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@@ -17,68 +27,53 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# --- Pydantic models for requests and responses ---
|
||||
|
||||
|
||||
class ChatMessageRequest(BaseModel):
|
||||
"""Request model for posting a chat message"""
|
||||
|
||||
thread_id: Optional[str] = Field(
|
||||
None, description="Thread ID (creates new thread if not provided)"
|
||||
)
|
||||
message: str = Field(..., description="User message content")
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
|
||||
"""Individual message in a thread"""
|
||||
|
||||
role: str = Field(..., description="Message role (user or assistant)")
|
||||
content: str = Field(..., description="Message content")
|
||||
timestamp: float = Field(..., description="Message timestamp (Unix timestamp)")
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
"""Response model for posting a chat message"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
||||
|
||||
|
||||
class ThreadSummary(BaseModel):
|
||||
"""Summary of a chat thread for list view"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
created_at: float = Field(..., description="Thread creation timestamp")
|
||||
last_message: str = Field(..., description="Last message content")
|
||||
message_count: int = Field(..., description="Total number of messages")
|
||||
|
||||
|
||||
class ThreadListResponse(BaseModel):
|
||||
"""Response model for listing all threads"""
|
||||
|
||||
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
||||
|
||||
|
||||
class ThreadDetail(BaseModel):
|
||||
"""Detailed view of a single thread"""
|
||||
|
||||
thread_id: str = Field(..., description="Thread ID")
|
||||
created_at: float = Field(..., description="Thread creation timestamp")
|
||||
messages: List[MessageItem] = Field(
|
||||
..., description="All messages in chronological order"
|
||||
)
|
||||
|
||||
|
||||
class DeleteResponse(BaseModel):
|
||||
"""Response model for delete operations"""
|
||||
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
thread_id: str = Field(..., description="Deleted thread ID")
|
||||
|
||||
|
||||
# --- Actual endpoints for chatbot ---
|
||||
|
||||
|
||||
@router.post("/message/stream")
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message_stream(
|
||||
*,
|
||||
request: Request,
|
||||
message_request: ChatMessageRequest,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Post a message to a chat thread with streaming progress updates.
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||
"""
|
||||
try:
|
||||
# Generate or use existing thread_id
|
||||
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
|
||||
|
||||
logger.info(
|
||||
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
chat_service.post_message_stream(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error posting chat message: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to post message: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/message", response_model=ChatMessageResponse)
|
||||
@limiter.limit("30/minute")
|
||||
async def post_chat_message(
|
||||
|
|
@@ -88,35 +83,31 @@ async def post_chat_message(
|
|||
currentUser: User = Depends(getCurrentUser),
|
||||
) -> ChatMessageResponse:
|
||||
"""
|
||||
Post a message to a chat thread and get assistant response.
|
||||
Post a message to a chat thread and get assistant response (non-streaming).
|
||||
Creates a new thread if thread_id is not provided.
|
||||
|
||||
This endpoint will later be connected to LangGraph's checkpointer.
|
||||
For streaming updates, use the /message/stream endpoint instead.
|
||||
"""
|
||||
try:
|
||||
# Generate or use existing thread_id
|
||||
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
|
||||
|
||||
# Get current timestamp
|
||||
current_time = datetime.now().timestamp()
|
||||
|
||||
# Create dummy message history
|
||||
# In production, this will fetch from LangGraph's checkpointer
|
||||
messages = [
|
||||
MessageItem(
|
||||
role="user", content=message_request.message, timestamp=current_time
|
||||
),
|
||||
MessageItem(
|
||||
role="assistant",
|
||||
content=f"Echo: {message_request.message} (This is a dummy response. LangGraph integration pending.)",
|
||||
timestamp=current_time + 0.5,
|
||||
),
|
||||
]
|
||||
|
||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
||||
|
||||
return ChatMessageResponse(thread_id=thread_id, messages=messages)
|
||||
response = await chat_service.post_message(
|
||||
thread_id=thread_id,
|
||||
message=message_request.message,
|
||||
user=currentUser,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Permission error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error posting chat message: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,7 @@ fastapi==0.104.1
 websockets==12.0
 uvicorn==0.23.2
 python-multipart==0.0.6
-httpx==0.25.0
+httpx>=0.25.2
 pydantic>=2.0.0  # Upgraded to v2 for LangChain compatibility
 email-validator==2.0.0  # Required by Pydantic for email validation
 slowapi==0.1.8  # For rate limiting

@@ -113,3 +113,7 @@ psycopg2-binary==2.9.9
 langchain==0.3.27
 langgraph==0.6.8
 langchain-core==0.3.77
+langchain-anthropic==0.3.1  # For Claude models
+psycopg[binary]==3.2.1  # For PostgreSQL async support (LangGraph checkpointer)
+psycopg-pool==3.2.1  # Connection pooling for PostgreSQL
+langgraph-checkpoint-postgres==2.0.24