feat: chatbot w/ streaming basics
This commit is contained in:
parent
8707203ac2
commit
4bfeded9d0
12 changed files with 1315 additions and 159 deletions
17
app.py
17
app.py
|
|
@ -240,9 +240,26 @@ instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
logger.info("Application is starting up")
|
logger.info("Application is starting up")
|
||||||
|
|
||||||
|
# Initialize LangGraph checkpointer
|
||||||
|
from modules.features.chatBot.utils.checkpointer import (
|
||||||
|
initialize_checkpointer,
|
||||||
|
close_checkpointer,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await initialize_checkpointer()
|
||||||
|
logger.info("LangGraph checkpointer initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize LangGraph checkpointer: {str(e)}")
|
||||||
|
# Continue startup even if checkpointer fails to initialize
|
||||||
|
|
||||||
eventManager.start()
|
eventManager.start()
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
eventManager.stop()
|
eventManager.stop()
|
||||||
|
await close_checkpointer()
|
||||||
logger.info("Application has been shut down")
|
logger.info("Application has been shut down")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,18 @@ import uuid
|
||||||
|
|
||||||
|
|
||||||
class ChatStat(BaseModel, ModelMixin):
|
class ChatStat(BaseModel, ModelMixin):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
id: str = Field(
|
||||||
workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
|
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||||
messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
|
)
|
||||||
processingTime: Optional[float] = Field(None, description="Processing time in seconds")
|
workflowId: Optional[str] = Field(
|
||||||
|
None, description="Foreign key to workflow (for workflow stats)"
|
||||||
|
)
|
||||||
|
messageId: Optional[str] = Field(
|
||||||
|
None, description="Foreign key to message (for message stats)"
|
||||||
|
)
|
||||||
|
processingTime: Optional[float] = Field(
|
||||||
|
None, description="Processing time in seconds"
|
||||||
|
)
|
||||||
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
|
tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
|
||||||
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
|
bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
|
||||||
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
|
bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
|
||||||
|
|
@ -37,14 +45,23 @@ register_model_labels(
|
||||||
|
|
||||||
|
|
||||||
class ChatLog(BaseModel, ModelMixin):
|
class ChatLog(BaseModel, ModelMixin):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
id: str = Field(
|
||||||
|
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||||
|
)
|
||||||
workflowId: str = Field(description="Foreign key to workflow")
|
workflowId: str = Field(description="Foreign key to workflow")
|
||||||
message: str = Field(description="Log message")
|
message: str = Field(description="Log message")
|
||||||
type: str = Field(description="Log type (info, warning, error, etc.)")
|
type: str = Field(description="Log type (info, warning, error, etc.)")
|
||||||
timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)")
|
timestamp: float = Field(
|
||||||
|
default_factory=get_utc_timestamp,
|
||||||
|
description="When the log entry was created (UTC timestamp in seconds)",
|
||||||
|
)
|
||||||
status: Optional[str] = Field(None, description="Status of the log entry")
|
status: Optional[str] = Field(None, description="Status of the log entry")
|
||||||
progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)")
|
progress: Optional[float] = Field(
|
||||||
performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")
|
None, description="Progress indicator (0.0 to 1.0)"
|
||||||
|
)
|
||||||
|
performance: Optional[Dict[str, Any]] = Field(
|
||||||
|
None, description="Performance metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -64,7 +81,9 @@ register_model_labels(
|
||||||
|
|
||||||
|
|
||||||
class ChatDocument(BaseModel, ModelMixin):
|
class ChatDocument(BaseModel, ModelMixin):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
id: str = Field(
|
||||||
|
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||||
|
)
|
||||||
messageId: str = Field(description="Foreign key to message")
|
messageId: str = Field(description="Foreign key to message")
|
||||||
fileId: str = Field(description="Foreign key to file")
|
fileId: str = Field(description="Foreign key to file")
|
||||||
fileName: str = Field(description="Name of the file")
|
fileName: str = Field(description="Name of the file")
|
||||||
|
|
@ -73,7 +92,9 @@ class ChatDocument(BaseModel, ModelMixin):
|
||||||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||||
actionId: Optional[str] = Field(None, description="ID of the action that created this document")
|
actionId: Optional[str] = Field(
|
||||||
|
None, description="ID of the action that created this document"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -96,13 +117,19 @@ register_model_labels(
|
||||||
|
|
||||||
class ContentMetadata(BaseModel, ModelMixin):
|
class ContentMetadata(BaseModel, ModelMixin):
|
||||||
size: int = Field(description="Content size in bytes")
|
size: int = Field(description="Content size in bytes")
|
||||||
pages: Optional[int] = Field(None, description="Number of pages for multi-page content")
|
pages: Optional[int] = Field(
|
||||||
|
None, description="Number of pages for multi-page content"
|
||||||
|
)
|
||||||
error: Optional[str] = Field(None, description="Processing error if any")
|
error: Optional[str] = Field(None, description="Processing error if any")
|
||||||
width: Optional[int] = Field(None, description="Width in pixels for images/videos")
|
width: Optional[int] = Field(None, description="Width in pixels for images/videos")
|
||||||
height: Optional[int] = Field(None, description="Height in pixels for images/videos")
|
height: Optional[int] = Field(
|
||||||
|
None, description="Height in pixels for images/videos"
|
||||||
|
)
|
||||||
colorMode: Optional[str] = Field(None, description="Color mode")
|
colorMode: Optional[str] = Field(None, description="Color mode")
|
||||||
fps: Optional[float] = Field(None, description="Frames per second for videos")
|
fps: Optional[float] = Field(None, description="Frames per second for videos")
|
||||||
durationSec: Optional[float] = Field(None, description="Duration in seconds for media")
|
durationSec: Optional[float] = Field(
|
||||||
|
None, description="Duration in seconds for media"
|
||||||
|
)
|
||||||
mimeType: str = Field(description="MIME type of the content")
|
mimeType: str = Field(description="MIME type of the content")
|
||||||
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
|
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
|
||||||
|
|
||||||
|
|
@ -144,7 +171,9 @@ register_model_labels(
|
||||||
|
|
||||||
class ExtractedContent(BaseModel, ModelMixin):
|
class ExtractedContent(BaseModel, ModelMixin):
|
||||||
id: str = Field(description="Reference to source ChatDocument")
|
id: str = Field(description="Reference to source ChatDocument")
|
||||||
contents: List[ContentItem] = Field(default_factory=list, description="List of content items")
|
contents: List[ContentItem] = Field(
|
||||||
|
default_factory=list, description="List of content items"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -156,27 +185,53 @@ register_model_labels(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel, ModelMixin):
|
class ChatMessage(BaseModel, ModelMixin):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
id: str = Field(
|
||||||
|
default_factory=lambda: str(uuid.uuid4()), description="Primary key"
|
||||||
|
)
|
||||||
workflowId: str = Field(description="Foreign key to workflow")
|
workflowId: str = Field(description="Foreign key to workflow")
|
||||||
parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading")
|
parentMessageId: Optional[str] = Field(
|
||||||
documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents")
|
None, description="Parent message ID for threading"
|
||||||
documentsLabel: Optional[str] = Field(None, description="Label for the set of documents")
|
)
|
||||||
|
documents: List[ChatDocument] = Field(
|
||||||
|
default_factory=list, description="Associated documents"
|
||||||
|
)
|
||||||
|
documentsLabel: Optional[str] = Field(
|
||||||
|
None, description="Label for the set of documents"
|
||||||
|
)
|
||||||
message: Optional[str] = Field(None, description="Message content")
|
message: Optional[str] = Field(None, description="Message content")
|
||||||
role: str = Field(description="Role of the message sender")
|
role: str = Field(description="Role of the message sender")
|
||||||
status: str = Field(description="Status of the message (first, step, last)")
|
status: str = Field(description="Status of the message (first, step, last)")
|
||||||
sequenceNr: int = Field(description="Sequence number of the message (set automatically)")
|
sequenceNr: int = Field(
|
||||||
publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)")
|
description="Sequence number of the message (set automatically)"
|
||||||
|
)
|
||||||
|
publishedAt: float = Field(
|
||||||
|
default_factory=get_utc_timestamp,
|
||||||
|
description="When the message was published (UTC timestamp in seconds)",
|
||||||
|
)
|
||||||
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
|
stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
|
||||||
success: Optional[bool] = Field(None, description="Whether the message processing was successful")
|
success: Optional[bool] = Field(
|
||||||
actionId: Optional[str] = Field(None, description="ID of the action that produced this message")
|
None, description="Whether the message processing was successful"
|
||||||
actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message")
|
)
|
||||||
actionName: Optional[str] = Field(None, description="Name of the action that produced this message")
|
actionId: Optional[str] = Field(
|
||||||
|
None, description="ID of the action that produced this message"
|
||||||
|
)
|
||||||
|
actionMethod: Optional[str] = Field(
|
||||||
|
None, description="Method of the action that produced this message"
|
||||||
|
)
|
||||||
|
actionName: Optional[str] = Field(
|
||||||
|
None, description="Name of the action that produced this message"
|
||||||
|
)
|
||||||
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
roundNumber: Optional[int] = Field(None, description="Round number in workflow")
|
||||||
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
taskNumber: Optional[int] = Field(None, description="Task number within round")
|
||||||
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
actionNumber: Optional[int] = Field(None, description="Action number within task")
|
||||||
taskProgress: Optional[str] = Field(None, description="Task progress status: pending, running, success, fail, retry")
|
taskProgress: Optional[str] = Field(
|
||||||
actionProgress: Optional[str] = Field(None, description="Action progress status: pending, running, success, fail")
|
None, description="Task progress status: pending, running, success, fail, retry"
|
||||||
|
)
|
||||||
|
actionProgress: Optional[str] = Field(
|
||||||
|
None, description="Action progress status: pending, running, success, fail"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -208,31 +263,135 @@ register_model_labels(
|
||||||
|
|
||||||
|
|
||||||
class ChatWorkflow(BaseModel, ModelMixin):
|
class ChatWorkflow(BaseModel, ModelMixin):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
id: str = Field(
|
||||||
mandateId: str = Field(description="ID of the mandate this workflow belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
default_factory=lambda: str(uuid.uuid4()),
|
||||||
status: str = Field(description="Current status of the workflow", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
description="Primary key",
|
||||||
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
|
frontend_type="text",
|
||||||
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
|
frontend_readonly=True,
|
||||||
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
|
frontend_required=False,
|
||||||
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
|
)
|
||||||
])
|
mandateId: str = Field(
|
||||||
name: Optional[str] = Field(None, description="Name of the workflow", frontend_type="text", frontend_readonly=False, frontend_required=True)
|
description="ID of the mandate this workflow belongs to",
|
||||||
currentRound: int = Field(description="Current round number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
frontend_type="text",
|
||||||
currentTask: int = Field(default=0, description="Current task number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
frontend_readonly=True,
|
||||||
currentAction: int = Field(default=0, description="Current action number", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
frontend_required=False,
|
||||||
totalTasks: int = Field(default=0, description="Total number of tasks in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
)
|
||||||
totalActions: int = Field(default=0, description="Total number of actions in the workflow", frontend_type="integer", frontend_readonly=True, frontend_required=False)
|
status: str = Field(
|
||||||
lastActivity: float = Field(default_factory=get_utc_timestamp, description="Timestamp of last activity (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
description="Current status of the workflow",
|
||||||
startedAt: float = Field(default_factory=get_utc_timestamp, description="When the workflow started (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
frontend_type="select",
|
||||||
logs: List[ChatLog] = Field(default_factory=list, description="Workflow logs", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
frontend_readonly=False,
|
||||||
messages: List[ChatMessage] = Field(default_factory=list, description="Messages in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
frontend_required=False,
|
||||||
stats: Optional[ChatStat] = Field(None, description="Workflow statistics", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
frontend_options=[
|
||||||
tasks: list = Field(default_factory=list, description="List of tasks in the workflow", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
{"value": "running", "label": {"en": "Running", "fr": "En cours"}},
|
||||||
workflowMode: str = Field(default="Actionplan", description="Workflow mode selector", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
{"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
|
||||||
{"value": "Actionplan", "label": {"en": "Action Plan", "fr": "Plan d'actions"}},
|
{"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
|
||||||
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
|
{"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
|
||||||
])
|
],
|
||||||
maxSteps: int = Field(default=5, description="Maximum number of iterations in react mode", frontend_type="integer", frontend_readonly=False, frontend_required=False)
|
)
|
||||||
|
name: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Name of the workflow",
|
||||||
|
frontend_type="text",
|
||||||
|
frontend_readonly=False,
|
||||||
|
frontend_required=True,
|
||||||
|
)
|
||||||
|
currentRound: int = Field(
|
||||||
|
description="Current round number",
|
||||||
|
frontend_type="integer",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
currentTask: int = Field(
|
||||||
|
default=0,
|
||||||
|
description="Current task number",
|
||||||
|
frontend_type="integer",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
currentAction: int = Field(
|
||||||
|
default=0,
|
||||||
|
description="Current action number",
|
||||||
|
frontend_type="integer",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
totalTasks: int = Field(
|
||||||
|
default=0,
|
||||||
|
description="Total number of tasks in the workflow",
|
||||||
|
frontend_type="integer",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
totalActions: int = Field(
|
||||||
|
default=0,
|
||||||
|
description="Total number of actions in the workflow",
|
||||||
|
frontend_type="integer",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
lastActivity: float = Field(
|
||||||
|
default_factory=get_utc_timestamp,
|
||||||
|
description="Timestamp of last activity (UTC timestamp in seconds)",
|
||||||
|
frontend_type="timestamp",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
startedAt: float = Field(
|
||||||
|
default_factory=get_utc_timestamp,
|
||||||
|
description="When the workflow started (UTC timestamp in seconds)",
|
||||||
|
frontend_type="timestamp",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
logs: List[ChatLog] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Workflow logs",
|
||||||
|
frontend_type="text",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
messages: List[ChatMessage] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Messages in the workflow",
|
||||||
|
frontend_type="text",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
stats: Optional[ChatStat] = Field(
|
||||||
|
None,
|
||||||
|
description="Workflow statistics",
|
||||||
|
frontend_type="text",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
tasks: list = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="List of tasks in the workflow",
|
||||||
|
frontend_type="text",
|
||||||
|
frontend_readonly=True,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
workflowMode: str = Field(
|
||||||
|
default="Actionplan",
|
||||||
|
description="Workflow mode selector",
|
||||||
|
frontend_type="select",
|
||||||
|
frontend_readonly=False,
|
||||||
|
frontend_required=False,
|
||||||
|
frontend_options=[
|
||||||
|
{
|
||||||
|
"value": "Actionplan",
|
||||||
|
"label": {"en": "Action Plan", "fr": "Plan d'actions"},
|
||||||
|
},
|
||||||
|
{"value": "React", "label": {"en": "React", "fr": "Réactif"}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
maxSteps: int = Field(
|
||||||
|
default=5,
|
||||||
|
description="Maximum number of iterations in react mode",
|
||||||
|
frontend_type="integer",
|
||||||
|
frontend_readonly=False,
|
||||||
|
frontend_required=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_labels(
|
register_model_labels(
|
||||||
|
|
@ -278,7 +437,10 @@ register_model_labels(
|
||||||
"completed_tasks": {"en": "Completed Tasks", "fr": "Tâches terminées"},
|
"completed_tasks": {"en": "Completed Tasks", "fr": "Tâches terminées"},
|
||||||
"total_tasks": {"en": "Total Tasks", "fr": "Total des tâches"},
|
"total_tasks": {"en": "Total Tasks", "fr": "Total des tâches"},
|
||||||
"execution_time": {"en": "Execution Time", "fr": "Temps d'exécution"},
|
"execution_time": {"en": "Execution Time", "fr": "Temps d'exécution"},
|
||||||
"final_results_count": {"en": "Final Results Count", "fr": "Nombre de résultats finaux"},
|
"final_results_count": {
|
||||||
|
"en": "Final Results Count",
|
||||||
|
"fr": "Nombre de résultats finaux",
|
||||||
|
},
|
||||||
"error": {"en": "Error", "fr": "Erreur"},
|
"error": {"en": "Error", "fr": "Erreur"},
|
||||||
"phase": {"en": "Phase", "fr": "Phase"},
|
"phase": {"en": "Phase", "fr": "Phase"},
|
||||||
},
|
},
|
||||||
|
|
@ -300,5 +462,3 @@ register_model_labels(
|
||||||
"userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"},
|
"userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
130
modules/datamodels/datamodelChatbot.py
Normal file
130
modules/datamodels/datamodelChatbot.py
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
"""Chatbot API models for request/response handling."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||||
|
|
||||||
|
|
||||||
|
# Chatbot API Models
|
||||||
|
class MessageItem(BaseModel, ModelMixin):
|
||||||
|
"""Individual message in a thread"""
|
||||||
|
|
||||||
|
role: str = Field(..., description="Message role (user or assistant)")
|
||||||
|
content: str = Field(..., description="Message content")
|
||||||
|
timestamp: float = Field(..., description="Message timestamp (Unix timestamp)")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageRequest(BaseModel, ModelMixin):
|
||||||
|
"""Request model for posting a chat message"""
|
||||||
|
|
||||||
|
thread_id: Optional[str] = Field(
|
||||||
|
None, description="Thread ID (creates new thread if not provided)"
|
||||||
|
)
|
||||||
|
message: str = Field(..., description="User message content")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageResponse(BaseModel, ModelMixin):
|
||||||
|
"""Response model for posting a chat message"""
|
||||||
|
|
||||||
|
thread_id: str = Field(..., description="Thread ID")
|
||||||
|
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSummary(BaseModel, ModelMixin):
|
||||||
|
"""Summary of a chat thread for list view"""
|
||||||
|
|
||||||
|
thread_id: str = Field(..., description="Thread ID")
|
||||||
|
created_at: float = Field(..., description="Thread creation timestamp")
|
||||||
|
last_message: str = Field(..., description="Last message content")
|
||||||
|
message_count: int = Field(..., description="Total number of messages")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadListResponse(BaseModel, ModelMixin):
|
||||||
|
"""Response model for listing all threads"""
|
||||||
|
|
||||||
|
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDetail(BaseModel, ModelMixin):
|
||||||
|
"""Detailed view of a single thread"""
|
||||||
|
|
||||||
|
thread_id: str = Field(..., description="Thread ID")
|
||||||
|
created_at: float = Field(..., description="Thread creation timestamp")
|
||||||
|
messages: List[MessageItem] = Field(
|
||||||
|
..., description="All messages in chronological order"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteResponse(BaseModel, ModelMixin):
|
||||||
|
"""Response model for delete operations"""
|
||||||
|
|
||||||
|
message: str = Field(..., description="Confirmation message")
|
||||||
|
thread_id: str = Field(..., description="Deleted thread ID")
|
||||||
|
|
||||||
|
|
||||||
|
# Register model labels for internationalization
|
||||||
|
register_model_labels(
|
||||||
|
"MessageItem",
|
||||||
|
{"en": "Message Item", "fr": "Élément de message"},
|
||||||
|
{
|
||||||
|
"role": {"en": "Role", "fr": "Rôle"},
|
||||||
|
"content": {"en": "Content", "fr": "Contenu"},
|
||||||
|
"timestamp": {"en": "Timestamp", "fr": "Horodatage"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
register_model_labels(
|
||||||
|
"ChatMessageRequest",
|
||||||
|
{"en": "Chat Message Request", "fr": "Demande de message de chat"},
|
||||||
|
{
|
||||||
|
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||||
|
"message": {"en": "Message", "fr": "Message"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
register_model_labels(
|
||||||
|
"ChatMessageResponse",
|
||||||
|
{"en": "Chat Message Response", "fr": "Réponse du message de chat"},
|
||||||
|
{
|
||||||
|
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||||
|
"messages": {"en": "Messages", "fr": "Messages"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
register_model_labels(
|
||||||
|
"ThreadSummary",
|
||||||
|
{"en": "Thread Summary", "fr": "Résumé du fil"},
|
||||||
|
{
|
||||||
|
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||||
|
"created_at": {"en": "Created At", "fr": "Créé le"},
|
||||||
|
"last_message": {"en": "Last Message", "fr": "Dernier message"},
|
||||||
|
"message_count": {"en": "Message Count", "fr": "Nombre de messages"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
register_model_labels(
|
||||||
|
"ThreadListResponse",
|
||||||
|
{"en": "Thread List Response", "fr": "Réponse de liste de fils"},
|
||||||
|
{
|
||||||
|
"threads": {"en": "Threads", "fr": "Fils"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
register_model_labels(
|
||||||
|
"ThreadDetail",
|
||||||
|
{"en": "Thread Detail", "fr": "Détail du fil"},
|
||||||
|
{
|
||||||
|
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||||
|
"created_at": {"en": "Created At", "fr": "Créé le"},
|
||||||
|
"messages": {"en": "Messages", "fr": "Messages"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
register_model_labels(
|
||||||
|
"DeleteResponse",
|
||||||
|
{"en": "Delete Response", "fr": "Réponse de suppression"},
|
||||||
|
{
|
||||||
|
"message": {"en": "Message", "fr": "Message"},
|
||||||
|
"thread_id": {"en": "Thread ID", "fr": "ID du fil"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
"""Tool for sending streaming status updates to users."""
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def send_streaming_message(message: str) -> str:
|
||||||
|
"""Send a streaming message to the user to provide updates during processing.
|
||||||
|
|
||||||
|
Use this tool to send short status updates to the user while you are working
|
||||||
|
on their request. This helps keep the user informed about what you are doing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: A short message describing what you are currently doing.
|
||||||
|
Examples: "Searching database for relevant information..."
|
||||||
|
"Analyzing search results..."
|
||||||
|
"Processing your request..."
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A confirmation that the message was sent.
|
||||||
|
"""
|
||||||
|
# This tool doesn't actually do anything - it's just for the AI to signal
|
||||||
|
# what it's doing to the frontend via the tool call mechanism
|
||||||
|
return f"Status update sent: {message}"
|
||||||
1
modules/features/chatBot/domain/__init__.py
Normal file
1
modules/features/chatBot/domain/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Domain logic for chatbot functionality."""
|
||||||
301
modules/features/chatBot/domain/chatbot.py
Normal file
301
modules/features/chatBot/domain/chatbot.py
Normal file
|
|
@ -0,0 +1,301 @@
|
||||||
|
"""Chatbot domain logic with LangGraph integration."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Annotated, AsyncIterator, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from langchain_core.messages import (
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
trim_messages,
|
||||||
|
)
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
from langgraph.prebuilt import ToolNode
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
|
||||||
|
from modules.features.chatBot.domain.streaming_helper import ChatStreamingHelper
|
||||||
|
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||||
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatState(BaseModel):
|
||||||
|
"""Represents the state of a chat session."""
|
||||||
|
|
||||||
|
messages: Annotated[list[BaseMessage], add_messages]
|
||||||
|
|
||||||
|
|
||||||
|
def get_langchain_model(*, model_name: str) -> ChatAnthropic:
|
||||||
|
"""Map permission model names to LangChain ChatAnthropic models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: The model name from permissions (e.g., "claude_4_5")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured ChatAnthropic instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the model name is not supported
|
||||||
|
"""
|
||||||
|
# Model name mapping
|
||||||
|
model_mapping = {
|
||||||
|
"claude_4_5": "claude-4-5-sonnet",
|
||||||
|
# Add more mappings as needed
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropic_model = model_mapping.get(model_name)
|
||||||
|
if not anthropic_model:
|
||||||
|
logger.warning(
|
||||||
|
f"Unknown model name '{model_name}', defaulting to claude-4-5-sonnet"
|
||||||
|
)
|
||||||
|
anthropic_model = "claude-4-5-sonnet"
|
||||||
|
|
||||||
|
return ChatAnthropic(
|
||||||
|
model=anthropic_model,
|
||||||
|
api_key=APP_CONFIG.get("Connector_AiAnthropic_API_SECRET"),
|
||||||
|
temperature=float(APP_CONFIG.get("Connector_AiAnthropic_TEMPERATURE", 0.2)),
|
||||||
|
max_tokens=int(APP_CONFIG.get("Connector_AiAnthropic_MAX_TOKENS", 2000)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Chatbot:
|
||||||
|
"""Represents a chatbot with LangGraph integration."""
|
||||||
|
|
||||||
|
model: Any
|
||||||
|
memory: Any
|
||||||
|
app: Any = None
|
||||||
|
system_prompt: str = "You are a helpful assistant."
|
||||||
|
context_window_size: int = 100000
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
model: Any,
|
||||||
|
memory: Any,
|
||||||
|
system_prompt: str,
|
||||||
|
tools: list,
|
||||||
|
context_window_size: int = 100000,
|
||||||
|
) -> "Chatbot":
|
||||||
|
"""Factory method to create and configure a Chatbot instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The chat model to use.
|
||||||
|
memory: The chat memory checkpointer to use.
|
||||||
|
system_prompt: The system prompt to initialize the chatbot.
|
||||||
|
tools: List of LangChain tools the chatbot can use.
|
||||||
|
context_window_size: Maximum tokens for context window.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A configured Chatbot instance.
|
||||||
|
"""
|
||||||
|
instance = cls(
|
||||||
|
model=model,
|
||||||
|
memory=memory,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
context_window_size=context_window_size,
|
||||||
|
)
|
||||||
|
instance.app = instance._build_app(memory=memory, tools=tools)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def _build_app(
|
||||||
|
self, *, memory: Any, tools: list
|
||||||
|
) -> CompiledStateGraph[ChatState, None, ChatState, ChatState]:
|
||||||
|
"""Builds the chatbot application workflow using LangGraph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory: The chat memory checkpointer to use.
|
||||||
|
tools: The list of tools the chatbot can use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A compiled state graph representing the chatbot application.
|
||||||
|
"""
|
||||||
|
llm_with_tools = self.model.bind_tools(tools=tools)
|
||||||
|
|
||||||
|
def select_window(msgs: list[BaseMessage]) -> list[BaseMessage]:
|
||||||
|
"""Selects a window of messages that fit within the context window size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msgs: The list of messages to select from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of messages that fit within the context window size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def approx_counter(items: list[BaseMessage]) -> int:
|
||||||
|
"""Approximate token counter for messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items: List of messages to count tokens for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Approximate number of tokens in the messages.
|
||||||
|
"""
|
||||||
|
return sum(len(getattr(m, "content", "") or "") for m in items)
|
||||||
|
|
||||||
|
return trim_messages(
|
||||||
|
msgs,
|
||||||
|
strategy="last",
|
||||||
|
token_counter=approx_counter,
|
||||||
|
max_tokens=self.context_window_size,
|
||||||
|
start_on="human",
|
||||||
|
end_on=("human", "tool"),
|
||||||
|
include_system=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def agent_node(state: ChatState) -> dict:
|
||||||
|
"""Agent node for the chatbot workflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current chat state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated chat state after processing.
|
||||||
|
"""
|
||||||
|
# Select the message window to fit in context (trim if needed)
|
||||||
|
window = select_window(state.messages)
|
||||||
|
|
||||||
|
# Ensure the system prompt is present at the start
|
||||||
|
if not window or not isinstance(window[0], SystemMessage):
|
||||||
|
window = [SystemMessage(content=self.system_prompt)] + window
|
||||||
|
|
||||||
|
# Call the LLM with tools
|
||||||
|
response = llm_with_tools.invoke(window)
|
||||||
|
|
||||||
|
# Return the new state
|
||||||
|
return {"messages": [response]}
|
||||||
|
|
||||||
|
def should_continue(state: ChatState) -> str:
|
||||||
|
"""Determines whether to continue the workflow or end it.
|
||||||
|
|
||||||
|
This conditional edge is called after the agent node to decide
|
||||||
|
whether to continue to the tools node (if the last message contains
|
||||||
|
tool calls) or to end the workflow (if no tool calls are present).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current chat state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The next node to transition to ("tools" or END).
|
||||||
|
"""
|
||||||
|
# Get the last message
|
||||||
|
last_message = state.messages[-1]
|
||||||
|
|
||||||
|
# Check if the last message contains tool calls
|
||||||
|
# If so, continue to the tools node; otherwise, end the workflow
|
||||||
|
return "tools" if getattr(last_message, "tool_calls", None) else END
|
||||||
|
|
||||||
|
# Compose the workflow
|
||||||
|
workflow = StateGraph(ChatState)
|
||||||
|
workflow.add_node("agent", agent_node)
|
||||||
|
workflow.add_node("tools", ToolNode(tools=tools))
|
||||||
|
workflow.add_edge(START, "agent")
|
||||||
|
workflow.add_conditional_edges("agent", should_continue)
|
||||||
|
workflow.add_edge("tools", "agent")
|
||||||
|
return workflow.compile(checkpointer=memory)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self, *, message: str, chat_id: str = "default"
|
||||||
|
) -> list[BaseMessage]:
|
||||||
|
"""Processes a chat message and returns the chat history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The user message to process.
|
||||||
|
chat_id: The chat thread ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of messages in the chat history.
|
||||||
|
"""
|
||||||
|
# Set the right thread ID for memory
|
||||||
|
config = {"configurable": {"thread_id": chat_id}}
|
||||||
|
|
||||||
|
# Single-turn chat (non-streaming)
|
||||||
|
result = await self.app.ainvoke(
|
||||||
|
{"messages": [HumanMessage(content=message)]}, config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract and return the messages from the result
|
||||||
|
return result["messages"]
|
||||||
|
|
||||||
|
async def stream_events(
|
||||||
|
self, *, message: str, chat_id: str = "default"
|
||||||
|
) -> AsyncIterator[dict]:
|
||||||
|
"""Stream UI-focused events using astream_events v2.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The user message to process.
|
||||||
|
chat_id: Logical thread identifier; forwarded in the runnable config so
|
||||||
|
memory and tools are scoped per thread.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
dict: One of:
|
||||||
|
- ``{"type": "status", "label": str}`` for short progress updates.
|
||||||
|
- ``{"type": "final", "response": {"thread": str, "chat_history": list[dict]}}``
|
||||||
|
where ``chat_history`` only includes ``user``/``assistant`` roles.
|
||||||
|
- ``{"type": "error", "message": str}`` if an exception occurs.
|
||||||
|
"""
|
||||||
|
# Thread-aware config for LangGraph/LangChain
|
||||||
|
config = {"configurable": {"thread_id": chat_id}}
|
||||||
|
|
||||||
|
def _is_root(ev: dict) -> bool:
|
||||||
|
"""Return True if the event is from the root run (v2: empty parent_ids)."""
|
||||||
|
return not ev.get("parent_ids")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for event in self.app.astream_events(
|
||||||
|
{"messages": [HumanMessage(content=message)]},
|
||||||
|
config=config,
|
||||||
|
version="v2",
|
||||||
|
):
|
||||||
|
etype = event.get("event")
|
||||||
|
ename = event.get("name") or ""
|
||||||
|
edata = event.get("data") or {}
|
||||||
|
|
||||||
|
# Stream human-readable progress via the special send_streaming_message tool
|
||||||
|
if etype == "on_tool_start" and ename == "send_streaming_message":
|
||||||
|
tool_in = edata.get("input") or {}
|
||||||
|
msg = tool_in.get("message")
|
||||||
|
if isinstance(msg, str) and msg.strip():
|
||||||
|
yield {"type": "status", "label": msg.strip()}
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Emit the final payload when the root run finishes
|
||||||
|
if etype == "on_chain_end" and _is_root(event):
|
||||||
|
output_obj = edata.get("output")
|
||||||
|
|
||||||
|
# Extract message list from the graph's final output
|
||||||
|
final_msgs = ChatStreamingHelper.extract_messages_from_output(
|
||||||
|
output_obj=output_obj
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize for the frontend (only user/assistant with text content)
|
||||||
|
chat_history_payload: list[dict] = []
|
||||||
|
for m in final_msgs:
|
||||||
|
if isinstance(m, BaseMessage):
|
||||||
|
d = ChatStreamingHelper.message_to_dict(msg=m)
|
||||||
|
elif isinstance(m, dict):
|
||||||
|
d = ChatStreamingHelper.dict_message_to_dict(obj=m)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if d.get("role") in ("user", "assistant") and d.get("content"):
|
||||||
|
chat_history_payload.append(d)
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"type": "final",
|
||||||
|
"response": {
|
||||||
|
"thread": chat_id,
|
||||||
|
"chat_history": chat_history_payload,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
# Emit a single error envelope and end the stream
|
||||||
|
logger.error(f"Error in stream_events: {str(exc)}", exc_info=True)
|
||||||
|
yield {"type": "error", "message": f"Error processing request: {exc}"}
|
||||||
239
modules/features/chatBot/domain/streaming_helper.py
Normal file
239
modules/features/chatBot/domain/streaming_helper.py
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
"""Streaming helper utilities for chat message processing and normalization."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Literal, Mapping, Optional
|
||||||
|
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
ToolMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
Role = Literal["user", "assistant", "system", "tool"]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatStreamingHelper:
|
||||||
|
"""Pure helper methods for streaming and message normalization.
|
||||||
|
|
||||||
|
This class provides static utility methods for converting between different
|
||||||
|
message formats, extracting content, and normalizing message structures
|
||||||
|
for streaming chat applications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def role_from_message(*, msg: BaseMessage) -> Role:
|
||||||
|
"""Extract the role from a BaseMessage instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg: The BaseMessage instance to extract the role from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The role as a string literal: "user", "assistant", "system", or "tool".
|
||||||
|
Defaults to "assistant" if the message type is not recognized.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from langchain_core.messages import HumanMessage
|
||||||
|
>>> msg = HumanMessage(content="Hello")
|
||||||
|
>>> ChatStreamingHelper.role_from_message(msg=msg)
|
||||||
|
'user'
|
||||||
|
"""
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
return "user"
|
||||||
|
if isinstance(msg, AIMessage):
|
||||||
|
return "assistant"
|
||||||
|
if isinstance(msg, SystemMessage):
|
||||||
|
return "system"
|
||||||
|
if isinstance(msg, ToolMessage):
|
||||||
|
return "tool"
|
||||||
|
return getattr(msg, "role", "assistant")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def flatten_content(*, content: Any) -> str:
|
||||||
|
"""Convert complex content structures to plain text.
|
||||||
|
|
||||||
|
This method handles various content formats including strings, lists of
|
||||||
|
content parts, and dictionaries with text fields. It's designed to
|
||||||
|
normalize content from different message sources into a consistent
|
||||||
|
plain text format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The content to flatten. Can be:
|
||||||
|
- str: Returned as-is after stripping whitespace
|
||||||
|
- list: Each item processed and joined with newlines
|
||||||
|
- dict: Text extracted from "text" or "content" fields
|
||||||
|
- None: Returns empty string
|
||||||
|
- Any other type: Converted to string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The flattened content as a plain text string with whitespace stripped.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]
|
||||||
|
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||||
|
'Hello
|
||||||
|
nworld'
|
||||||
|
|
||||||
|
>>> content = {"text": "Simple message"}
|
||||||
|
>>> ChatStreamingHelper.flatten_content(content=content)
|
||||||
|
'Simple message'
|
||||||
|
"""
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content.strip()
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: List[str] = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict):
|
||||||
|
if "text" in part and isinstance(part["text"], str):
|
||||||
|
parts.append(part["text"])
|
||||||
|
elif part.get("type") == "text" and isinstance(
|
||||||
|
part.get("text"), str
|
||||||
|
):
|
||||||
|
parts.append(part["text"])
|
||||||
|
elif "content" in part and isinstance(part["content"], str):
|
||||||
|
parts.append(part["content"])
|
||||||
|
else:
|
||||||
|
# Fallback for unknown dictionary structures
|
||||||
|
val = part.get("value")
|
||||||
|
if isinstance(val, str):
|
||||||
|
parts.append(val)
|
||||||
|
else:
|
||||||
|
parts.append(str(part))
|
||||||
|
return "\n".join(p.strip() for p in parts if p is not None)
|
||||||
|
if isinstance(content, dict):
|
||||||
|
if "text" in content and isinstance(content["text"], str):
|
||||||
|
return content["text"].strip()
|
||||||
|
if "content" in content and isinstance(content["content"], str):
|
||||||
|
return content["content"].strip()
|
||||||
|
return str(content).strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def message_to_dict(*, msg: BaseMessage) -> Dict[str, Any]:
|
||||||
|
"""Convert a BaseMessage instance to a dictionary for streaming output.
|
||||||
|
|
||||||
|
This method normalizes BaseMessage instances into a consistent dictionary
|
||||||
|
format suitable for JSON serialization and streaming to clients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg: The BaseMessage instance to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing:
|
||||||
|
- "role": The message role (user, assistant, system, tool)
|
||||||
|
- "content": The flattened message content as plain text
|
||||||
|
- "tool_calls": Tool calls if present (optional)
|
||||||
|
- "name": Message name if present (optional)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from langchain_core.messages import HumanMessage
|
||||||
|
>>> msg = HumanMessage(content="Hello there")
|
||||||
|
>>> result = ChatStreamingHelper.message_to_dict(msg=msg)
|
||||||
|
>>> result["role"]
|
||||||
|
'user'
|
||||||
|
>>> result["content"]
|
||||||
|
'Hello there'
|
||||||
|
"""
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"role": ChatStreamingHelper.role_from_message(msg=msg),
|
||||||
|
"content": ChatStreamingHelper.flatten_content(
|
||||||
|
content=getattr(msg, "content", "")
|
||||||
|
),
|
||||||
|
}
|
||||||
|
tool_calls = getattr(msg, "tool_calls", None)
|
||||||
|
if tool_calls:
|
||||||
|
payload["tool_calls"] = tool_calls
|
||||||
|
name = getattr(msg, "name", None)
|
||||||
|
if name:
|
||||||
|
payload["name"] = name
|
||||||
|
return payload
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dict_message_to_dict(*, obj: Mapping[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Convert a dictionary-shaped message to a normalized dictionary.
|
||||||
|
|
||||||
|
This method handles messages that come from serialized state and are
|
||||||
|
represented as dictionaries rather than BaseMessage instances. It
|
||||||
|
normalizes various dictionary formats into a consistent structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The dictionary-shaped message to convert. Expected to contain
|
||||||
|
fields like "role", "type", "content", "text", etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A normalized dictionary containing:
|
||||||
|
- "role": The message role (user, assistant, system, tool)
|
||||||
|
- "content": The flattened message content as plain text
|
||||||
|
- "tool_calls": Tool calls if present (optional)
|
||||||
|
- "name": Message name if present (optional)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> obj = {"type": "human", "content": "Hello"}
|
||||||
|
>>> result = ChatStreamingHelper.dict_message_to_dict(obj=obj)
|
||||||
|
>>> result["role"]
|
||||||
|
'user'
|
||||||
|
>>> result["content"]
|
||||||
|
'Hello'
|
||||||
|
"""
|
||||||
|
role: Optional[str] = obj.get("role")
|
||||||
|
if not role:
|
||||||
|
# Handle alternative type field mappings
|
||||||
|
typ = obj.get("type")
|
||||||
|
if typ in ("human", "user"):
|
||||||
|
role = "user"
|
||||||
|
elif typ in ("ai", "assistant"):
|
||||||
|
role = "assistant"
|
||||||
|
elif typ in ("system",):
|
||||||
|
role = "system"
|
||||||
|
elif typ in ("tool", "function"):
|
||||||
|
role = "tool"
|
||||||
|
|
||||||
|
content = obj.get("content")
|
||||||
|
if content is None and "text" in obj:
|
||||||
|
content = obj["text"]
|
||||||
|
|
||||||
|
out: Dict[str, Any] = {
|
||||||
|
"role": role or "assistant",
|
||||||
|
"content": ChatStreamingHelper.flatten_content(content=content),
|
||||||
|
}
|
||||||
|
if "tool_calls" in obj:
|
||||||
|
out["tool_calls"] = obj["tool_calls"]
|
||||||
|
if obj.get("name"):
|
||||||
|
out["name"] = obj["name"]
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_messages_from_output(*, output_obj: Any) -> List[Any]:
|
||||||
|
"""Extract messages from LangGraph output objects.
|
||||||
|
|
||||||
|
This method handles various output formats from LangGraph execution,
|
||||||
|
extracting the messages list from different possible structures.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_obj: The output object from LangGraph execution. Can be:
|
||||||
|
- An object with a "messages" attribute
|
||||||
|
- A dictionary with a "messages" key
|
||||||
|
- Any other object (returns empty list)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of extracted messages, or an empty list if no messages
|
||||||
|
are found or if the output object is None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> output = {"messages": [{"role": "user", "content": "Hello"}]}
|
||||||
|
>>> messages = ChatStreamingHelper.extract_messages_from_output(output_obj=output)
|
||||||
|
>>> len(messages)
|
||||||
|
1
|
||||||
|
"""
|
||||||
|
if output_obj is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Try to parse dicts first
|
||||||
|
if isinstance(output_obj, dict):
|
||||||
|
msgs = output_obj.get("messages")
|
||||||
|
return msgs if isinstance(msgs, list) else []
|
||||||
|
|
||||||
|
# Then try to get messages attribute
|
||||||
|
msgs = getattr(output_obj, "messages", None)
|
||||||
|
return msgs if isinstance(msgs, list) else []
|
||||||
215
modules/features/chatBot/service.py
Normal file
215
modules/features/chatBot/service.py
Normal file
|
|
@ -0,0 +1,215 @@
|
||||||
|
"""Service layer for chatbot functionality."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import AsyncIterator, List
|
||||||
|
|
||||||
|
from modules.features.chatBot.domain.chatbot import Chatbot, get_langchain_model
|
||||||
|
from modules.features.chatBot.utils.checkpointer import get_checkpointer
|
||||||
|
from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||||
|
from modules.features.chatBot.utils import permissions
|
||||||
|
from modules.datamodels.datamodelChatbot import MessageItem, ChatMessageResponse
|
||||||
|
from modules.datamodels.datamodelUam import User
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def post_message(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
message: str,
|
||||||
|
user: User,
|
||||||
|
) -> ChatMessageResponse:
|
||||||
|
"""Post a chat message to the chatbot and return the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: The unique identifier for the chat thread.
|
||||||
|
message: The content of the chat message.
|
||||||
|
user: The current user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The response containing the full chat message history and thread ID.
|
||||||
|
"""
|
||||||
|
logger.info(f"User {user.id} posted message to thread {thread_id}")
|
||||||
|
|
||||||
|
# Get user permissions
|
||||||
|
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||||
|
if not tool_ids:
|
||||||
|
raise ValueError("User does not have permission to use any chatbot tools")
|
||||||
|
|
||||||
|
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||||
|
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||||
|
|
||||||
|
# Get tools from registry
|
||||||
|
registry = get_registry()
|
||||||
|
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||||
|
|
||||||
|
# Get model and checkpointer
|
||||||
|
model = get_langchain_model(model_name=model_name)
|
||||||
|
checkpointer = get_checkpointer()
|
||||||
|
|
||||||
|
# Get context window size from config
|
||||||
|
context_window_size = int(
|
||||||
|
APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create chatbot instance
|
||||||
|
chatbot = await Chatbot.create(
|
||||||
|
model=model,
|
||||||
|
memory=checkpointer,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
context_window_size=context_window_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send message to chatbot
|
||||||
|
response = await chatbot.chat(message=message, chat_id=thread_id)
|
||||||
|
|
||||||
|
# Parse the response to the correct format
|
||||||
|
messages = []
|
||||||
|
for msg in response:
|
||||||
|
# Determine the role of the message
|
||||||
|
if isinstance(msg, HumanMessage):
|
||||||
|
role = "user"
|
||||||
|
elif isinstance(msg, AIMessage):
|
||||||
|
role = "assistant"
|
||||||
|
else:
|
||||||
|
continue # Skip any other message types
|
||||||
|
|
||||||
|
# Skip messages that are structured content, such as tool calls
|
||||||
|
if not isinstance(msg.content, str):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Append message to chat history
|
||||||
|
item = MessageItem(
|
||||||
|
role=role,
|
||||||
|
content=msg.content.strip(),
|
||||||
|
timestamp=0.0, # TODO: Add proper timestamp handling
|
||||||
|
)
|
||||||
|
messages.append(item)
|
||||||
|
|
||||||
|
return ChatMessageResponse(thread_id=thread_id, messages=messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def post_message_stream(
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
message: str,
|
||||||
|
user: User,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
"""Post a chat message to the chatbot and stream progress updates (SSE).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: The unique identifier for the chat thread.
|
||||||
|
message: The content of the chat message.
|
||||||
|
user: The current user.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Server-Sent Events formatted strings containing status updates and final response.
|
||||||
|
"""
|
||||||
|
logger.info(f"User {user.id} streaming message to thread {thread_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get user permissions
|
||||||
|
tool_ids = permissions.get_chatbot_tools(user_id=user.id)
|
||||||
|
if not tool_ids:
|
||||||
|
yield (
|
||||||
|
"data: "
|
||||||
|
+ json.dumps(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"message": "User does not have permission to use any chatbot tools",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n\n"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
model_name = permissions.get_chatbot_model(user_id=user.id)
|
||||||
|
system_prompt = permissions.get_system_prompt(user_id=user.id)
|
||||||
|
|
||||||
|
# Get tools from registry
|
||||||
|
registry = get_registry()
|
||||||
|
tools = registry.get_tool_instances(tool_ids=tool_ids)
|
||||||
|
|
||||||
|
# Get model and checkpointer
|
||||||
|
model = get_langchain_model(model_name=model_name)
|
||||||
|
checkpointer = get_checkpointer()
|
||||||
|
|
||||||
|
# Get context window size from config
|
||||||
|
context_window_size = int(
|
||||||
|
APP_CONFIG.get("CHATBOT_CONTEXT_WINDOW_TOKEN_SIZE", 100000)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create chatbot instance
|
||||||
|
chatbot = await Chatbot.create(
|
||||||
|
model=model,
|
||||||
|
memory=checkpointer,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
context_window_size=context_window_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream events from chatbot
|
||||||
|
async for event in chatbot.stream_events(message=message, chat_id=thread_id):
|
||||||
|
etype = event.get("type")
|
||||||
|
|
||||||
|
# Forward status updates
|
||||||
|
if etype == "status":
|
||||||
|
yield f"data: {json.dumps({'type': 'status', 'label': event.get('label')})}\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Forward final response
|
||||||
|
if etype == "final":
|
||||||
|
response_from_event = event.get("response") or {}
|
||||||
|
|
||||||
|
# Use the chat history from the final event (already normalized by stream_events)
|
||||||
|
chat_history_payload = response_from_event.get("chat_history", [])
|
||||||
|
if isinstance(chat_history_payload, list):
|
||||||
|
# Convert to MessageItem format
|
||||||
|
items: List[MessageItem] = []
|
||||||
|
for it in chat_history_payload:
|
||||||
|
role = it.get("role")
|
||||||
|
content = it.get("content", "")
|
||||||
|
if role in ("user", "assistant") and content:
|
||||||
|
items.append(
|
||||||
|
MessageItem(
|
||||||
|
role=role,
|
||||||
|
content=content,
|
||||||
|
timestamp=0.0, # TODO: Add proper timestamp handling
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = ChatMessageResponse(thread_id=thread_id, messages=items)
|
||||||
|
# Yield the final response and exit
|
||||||
|
yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n"
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Unexpected payload format - log warning and return empty history
|
||||||
|
logger.warning(
|
||||||
|
f"Unexpected chat_history format in final event: {type(chat_history_payload)}"
|
||||||
|
)
|
||||||
|
response = ChatMessageResponse(thread_id=thread_id, messages=[])
|
||||||
|
yield f"data: {json.dumps({'type': 'final', 'response': response.model_dump()})}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# Forward error events
|
||||||
|
if etype == "error":
|
||||||
|
yield f"data: {json.dumps(event)}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in streaming chat: {str(e)}", exc_info=True)
|
||||||
|
yield (
|
||||||
|
"data: "
|
||||||
|
+ json.dumps(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"message": "An error occurred while processing your request.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n\n"
|
||||||
|
)
|
||||||
95
modules/features/chatBot/utils/checkpointer.py
Normal file
95
modules/features/chatBot/utils/checkpointer.py
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""PostgreSQL checkpointer utilities for LangGraph memory."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langgraph.checkpoint.postgres import PostgresSaver
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global checkpointer instance
|
||||||
|
_checkpointer_instance: Optional[PostgresSaver] = None
|
||||||
|
_connection_pool: Optional[AsyncConnectionPool] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_checkpointer() -> None:
|
||||||
|
"""Initialize the PostgreSQL checkpointer for LangGraph.
|
||||||
|
|
||||||
|
This should be called during application startup.
|
||||||
|
Creates a connection pool and PostgresSaver instance.
|
||||||
|
"""
|
||||||
|
global _checkpointer_instance, _connection_pool
|
||||||
|
|
||||||
|
if _checkpointer_instance is not None:
|
||||||
|
logger.info("Checkpointer already initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get database configuration from environment
|
||||||
|
host = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost")
|
||||||
|
database = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat")
|
||||||
|
user = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev")
|
||||||
|
password = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET")
|
||||||
|
port = APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432")
|
||||||
|
|
||||||
|
# Build connection string
|
||||||
|
connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}"
|
||||||
|
|
||||||
|
# Create async connection pool
|
||||||
|
_connection_pool = AsyncConnectionPool(
|
||||||
|
conninfo=connection_string,
|
||||||
|
min_size=2,
|
||||||
|
max_size=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the connection pool
|
||||||
|
await _connection_pool.open()
|
||||||
|
|
||||||
|
# Create PostgresSaver with the pool
|
||||||
|
_checkpointer_instance = PostgresSaver(_connection_pool)
|
||||||
|
|
||||||
|
# Setup the checkpointer (creates tables if needed)
|
||||||
|
await _checkpointer_instance.setup()
|
||||||
|
|
||||||
|
logger.info("PostgreSQL checkpointer initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize PostgreSQL checkpointer: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def close_checkpointer() -> None:
|
||||||
|
"""Close the checkpointer and connection pool.
|
||||||
|
|
||||||
|
This should be called during application shutdown.
|
||||||
|
"""
|
||||||
|
global _checkpointer_instance, _connection_pool
|
||||||
|
|
||||||
|
if _connection_pool is not None:
|
||||||
|
try:
|
||||||
|
await _connection_pool.close()
|
||||||
|
logger.info("PostgreSQL checkpointer connection pool closed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing checkpointer connection pool: {str(e)}")
|
||||||
|
|
||||||
|
_checkpointer_instance = None
|
||||||
|
_connection_pool = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpointer() -> PostgresSaver:
|
||||||
|
"""Get the global PostgreSQL checkpointer instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The initialized PostgresSaver instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If checkpointer is not initialized
|
||||||
|
"""
|
||||||
|
if _checkpointer_instance is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"PostgreSQL checkpointer not initialized. "
|
||||||
|
"Call initialize_checkpointer() during application startup."
|
||||||
|
)
|
||||||
|
return _checkpointer_instance
|
||||||
|
|
@ -12,36 +12,15 @@ from modules.features.chatBot.utils.toolRegistry import get_registry
|
||||||
# TODO: Replace these mock implementations with actual database queries
|
# TODO: Replace these mock implementations with actual database queries
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_tools(*, user_id: str) -> list[str]:
|
def get_chatbot_tools(*, user_id: str) -> list[str]:
|
||||||
"""Get list of tool IDs that a user is allowed to use.
|
"""Get list of tool IDs that the chatbot can use for a given user."""
|
||||||
|
|
||||||
This is a mock implementation that returns all available tools
|
|
||||||
regardless of user_id. In production, this will query the database
|
|
||||||
for user-specific permissions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The unique identifier of the user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of tool IDs (e.g., ["shared.tavily_search", "customer.query_althaus_database"])
|
|
||||||
"""
|
|
||||||
registry = get_registry()
|
registry = get_registry()
|
||||||
return registry.list_tool_ids()
|
return registry.list_tool_ids()
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_models(*, user_id: str) -> list[str]:
|
def get_chatbot_model(*, user_id: str) -> str:
|
||||||
"""Get list of AI models that a user is allowed to use.
|
"""Gets the chatbot model(s) a user is allowed to use."""
|
||||||
|
return "claude_4_5"
|
||||||
This is a mock implementation that returns a fixed list of models.
|
|
||||||
In production, this will query the database for user-specific model permissions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The unique identifier of the user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model identifiers (e.g., ["gpt-5", "claude-4-5"])
|
|
||||||
"""
|
|
||||||
return ["gpt-5", "claude-4-5"]
|
|
||||||
|
|
||||||
|
|
||||||
def get_system_prompt(*, user_id: str) -> str:
|
def get_system_prompt(*, user_id: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,23 @@
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from modules.datamodels.datamodelUam import User
|
from modules.datamodels.datamodelUam import User
|
||||||
|
from modules.datamodels.datamodelChatbot import (
|
||||||
|
ChatMessageRequest,
|
||||||
|
MessageItem,
|
||||||
|
ChatMessageResponse,
|
||||||
|
ThreadSummary,
|
||||||
|
ThreadListResponse,
|
||||||
|
ThreadDetail,
|
||||||
|
DeleteResponse,
|
||||||
|
)
|
||||||
from modules.security.auth import getCurrentUser, limiter
|
from modules.security.auth import getCurrentUser, limiter
|
||||||
|
from modules.features.chatBot import service as chat_service
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -17,68 +27,53 @@ router = APIRouter(
|
||||||
responses={404: {"description": "Not found"}},
|
responses={404: {"description": "Not found"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Pydantic models for requests and responses ---
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageRequest(BaseModel):
|
|
||||||
"""Request model for posting a chat message"""
|
|
||||||
|
|
||||||
thread_id: Optional[str] = Field(
|
|
||||||
None, description="Thread ID (creates new thread if not provided)"
|
|
||||||
)
|
|
||||||
message: str = Field(..., description="User message content")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageItem(BaseModel):
|
|
||||||
"""Individual message in a thread"""
|
|
||||||
|
|
||||||
role: str = Field(..., description="Message role (user or assistant)")
|
|
||||||
content: str = Field(..., description="Message content")
|
|
||||||
timestamp: float = Field(..., description="Message timestamp (Unix timestamp)")
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageResponse(BaseModel):
|
|
||||||
"""Response model for posting a chat message"""
|
|
||||||
|
|
||||||
thread_id: str = Field(..., description="Thread ID")
|
|
||||||
messages: List[MessageItem] = Field(..., description="All messages in thread")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadSummary(BaseModel):
|
|
||||||
"""Summary of a chat thread for list view"""
|
|
||||||
|
|
||||||
thread_id: str = Field(..., description="Thread ID")
|
|
||||||
created_at: float = Field(..., description="Thread creation timestamp")
|
|
||||||
last_message: str = Field(..., description="Last message content")
|
|
||||||
message_count: int = Field(..., description="Total number of messages")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadListResponse(BaseModel):
|
|
||||||
"""Response model for listing all threads"""
|
|
||||||
|
|
||||||
threads: List[ThreadSummary] = Field(..., description="List of thread summaries")
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadDetail(BaseModel):
|
|
||||||
"""Detailed view of a single thread"""
|
|
||||||
|
|
||||||
thread_id: str = Field(..., description="Thread ID")
|
|
||||||
created_at: float = Field(..., description="Thread creation timestamp")
|
|
||||||
messages: List[MessageItem] = Field(
|
|
||||||
..., description="All messages in chronological order"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteResponse(BaseModel):
|
|
||||||
"""Response model for delete operations"""
|
|
||||||
|
|
||||||
message: str = Field(..., description="Confirmation message")
|
|
||||||
thread_id: str = Field(..., description="Deleted thread ID")
|
|
||||||
|
|
||||||
|
|
||||||
# --- Actual endpoints for chatbot ---
|
# --- Actual endpoints for chatbot ---
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/message/stream")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
async def post_chat_message_stream(
|
||||||
|
*,
|
||||||
|
request: Request,
|
||||||
|
message_request: ChatMessageRequest,
|
||||||
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
) -> StreamingResponse:
|
||||||
|
"""
|
||||||
|
Post a message to a chat thread with streaming progress updates.
|
||||||
|
Creates a new thread if thread_id is not provided.
|
||||||
|
|
||||||
|
Returns Server-Sent Events (SSE) stream with status updates and final response.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Generate or use existing thread_id
|
||||||
|
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"User {currentUser.id} posted streaming message to thread {thread_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
chat_service.post_message_stream(
|
||||||
|
thread_id=thread_id,
|
||||||
|
message=message_request.message,
|
||||||
|
user=currentUser,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error posting chat message: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to post message: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/message", response_model=ChatMessageResponse)
|
@router.post("/message", response_model=ChatMessageResponse)
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
async def post_chat_message(
|
async def post_chat_message(
|
||||||
|
|
@ -88,35 +83,31 @@ async def post_chat_message(
|
||||||
currentUser: User = Depends(getCurrentUser),
|
currentUser: User = Depends(getCurrentUser),
|
||||||
) -> ChatMessageResponse:
|
) -> ChatMessageResponse:
|
||||||
"""
|
"""
|
||||||
Post a message to a chat thread and get assistant response.
|
Post a message to a chat thread and get assistant response (non-streaming).
|
||||||
Creates a new thread if thread_id is not provided.
|
Creates a new thread if thread_id is not provided.
|
||||||
|
|
||||||
This endpoint will later be connected to LangGraph's checkpointer.
|
For streaming updates, use the /message/stream endpoint instead.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Generate or use existing thread_id
|
# Generate or use existing thread_id
|
||||||
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
|
thread_id = message_request.thread_id or f"thread_{uuid.uuid4()}"
|
||||||
|
|
||||||
# Get current timestamp
|
|
||||||
current_time = datetime.now().timestamp()
|
|
||||||
|
|
||||||
# Create dummy message history
|
|
||||||
# In production, this will fetch from LangGraph's checkpointer
|
|
||||||
messages = [
|
|
||||||
MessageItem(
|
|
||||||
role="user", content=message_request.message, timestamp=current_time
|
|
||||||
),
|
|
||||||
MessageItem(
|
|
||||||
role="assistant",
|
|
||||||
content=f"Echo: {message_request.message} (This is a dummy response. LangGraph integration pending.)",
|
|
||||||
timestamp=current_time + 0.5,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
logger.info(f"User {currentUser.id} posted message to thread {thread_id}")
|
||||||
|
|
||||||
return ChatMessageResponse(thread_id=thread_id, messages=messages)
|
response = await chat_service.post_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
message=message_request.message,
|
||||||
|
user=currentUser,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Permission error: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error posting chat message: {str(e)}")
|
logger.error(f"Error posting chat message: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ fastapi==0.104.1
|
||||||
websockets==12.0
|
websockets==12.0
|
||||||
uvicorn==0.23.2
|
uvicorn==0.23.2
|
||||||
python-multipart==0.0.6
|
python-multipart==0.0.6
|
||||||
httpx==0.25.0
|
httpx>=0.25.2
|
||||||
pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility
|
pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility
|
||||||
email-validator==2.0.0 # Required by Pydantic for email validation
|
email-validator==2.0.0 # Required by Pydantic for email validation
|
||||||
slowapi==0.1.8 # For rate limiting
|
slowapi==0.1.8 # For rate limiting
|
||||||
|
|
@ -113,3 +113,7 @@ psycopg2-binary==2.9.9
|
||||||
langchain==0.3.27
|
langchain==0.3.27
|
||||||
langgraph==0.6.8
|
langgraph==0.6.8
|
||||||
langchain-core==0.3.77
|
langchain-core==0.3.77
|
||||||
|
langchain-anthropic==0.3.1 # For Claude models
|
||||||
|
psycopg[binary]==3.2.1 # For PostgreSQL async support (LangGraph checkpointer)
|
||||||
|
psycopg-pool==3.2.1 # Connection pooling for PostgreSQL
|
||||||
|
langgraph-checkpoint-postgres==2.0.24
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue