# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""
Chatbot V2 checkpointer - maps LangGraph state to ChatbotV2 message storage.
"""

import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple, NamedTuple

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata


class CheckpointTuple(NamedTuple):
    """Tuple containing config, checkpoint, metadata, parent_config, and pending_writes."""

    config: Dict[str, Any]
    checkpoint: Checkpoint
    metadata: CheckpointMetadata
    parent_config: Optional[Dict[str, Any]] = None
    pending_writes: Optional[List[Tuple[str, Any]]] = None


from modules.features.chatbotV2.interfaceFeatureChatbotV2 import getInterface as getChatbotV2Interface
from modules.features.chatbotV2.datamodelFeatureChatbotV2 import ChatbotV2Message
from modules.datamodels.datamodelUam import User
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)


def _sanitize_llm_response(text: str) -> str:
    """Strip chat template tokens and trailing junk."""
    if not text or not isinstance(text, str):
        return text or ""
    for sentinel in ("<|im_start|>", "<|im_end|>", "<|endoftext|>", "<|user|>", "<|assistant|>"):
        if sentinel in text:
            text = text.split(sentinel)[0]
    return text.strip()
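

# Illustrative behaviour (sketch): everything from the first leaked template token onward
# is discarded, e.g. _sanitize_llm_response("Sure, done.<|im_end|>\n<|im_start|>user") -> "Sure, done."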


class ChatbotV2Checkpointer(BaseCheckpointSaver):
    """Checkpointer that stores messages via ChatbotV2 interface."""

    def __init__(
        self,
        user: User,
        workflow_id: str,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ):
        self.user = user
        self.workflow_id = workflow_id
        self.interface = getChatbotV2Interface(
            user,
            mandateId=mandateId,
            featureInstanceId=featureInstanceId
        )

    def _to_db_message(
        self,
        msg: BaseMessage,
        sequence_nr: int,
        round_number: int
    ) -> Dict[str, Any]:
        """Map a LangChain message to a ChatbotV2 message dict."""
        role = "user"
        content = ""
        if isinstance(msg, HumanMessage):
            role = "user"
            content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
        elif isinstance(msg, AIMessage):
            role = "assistant"
            content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
            content = _sanitize_llm_response(content)
        elif isinstance(msg, SystemMessage):
            role = "system"
            content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
        return {
            "id": str(uuid.uuid4()),
            "conversationId": self.workflow_id,
            "message": content,
            "role": role,
            "status": "step" if sequence_nr > 1 else "first",
            "sequenceNr": sequence_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": round_number
        }

    def _to_langchain(self, messages: List[ChatbotV2Message]) -> List[BaseMessage]:
        """Map stored ChatbotV2 messages back to LangChain messages."""
        result = []
        for m in messages:
            if m.role == "user":
                result.append(HumanMessage(content=m.message or ""))
            elif m.role == "assistant":
                result.append(AIMessage(content=m.message or ""))
            elif m.role == "system":
                result.append(SystemMessage(content=m.message or ""))
        return result

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """Persist new user/assistant messages from the checkpoint into ChatbotV2 storage."""
        thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
        conv = self.interface.getConversation(thread_id)
        if not conv:
            logger.warning(f"Conversation {thread_id} not found")
            return
        round_number = conv.currentRound or 1
        state = checkpoint.get("channel_values", {})
        messages = state.get("messages", [])
        if not messages:
            return
        # Deduplicate on (role, content) so replayed checkpoints do not create duplicate rows.
        existing = self.interface.getMessages(thread_id)
        existing_set = {(m.role, m.message) for m in existing}
        existing_count = len(existing)
        for msg in messages:
            if not isinstance(msg, (HumanMessage, AIMessage)):
                continue
            role = "user" if isinstance(msg, HumanMessage) else "assistant"
            content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
            if isinstance(msg, AIMessage):
                content = _sanitize_llm_response(content)
            if not content or not content.strip():
                continue
            if (role, content) in existing_set:
                continue
            existing_set.add((role, content))
            existing_count += 1
            db_msg = self._to_db_message(msg, existing_count, round_number)
            self.interface.createMessage(db_msg)
        self.interface.updateConversation(thread_id, {"lastActivity": getUtcTimestamp()})

    def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]:
        """Rebuild a minimal checkpoint from the messages stored for this conversation."""
        thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
        conv = self.interface.getConversation(thread_id)
        if not conv:
            return None
        messages = self.interface.getMessages(thread_id)
        lc_messages = self._to_langchain(messages)
        return {
            "id": str(uuid.uuid4()),
            "v": 1,
            "ts": getUtcTimestamp(),
            "channel_values": {"messages": lc_messages},
            "channel_versions": {},
            "versions_seen": {}
        }

    # Async methods required for LangGraph ainvoke/astream

    async def aget_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """Async version of get, wrapping the rebuilt checkpoint in a CheckpointTuple."""
        checkpoint = self.get(config)
        if checkpoint:
            metadata: CheckpointMetadata = {"step": 0}
            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                parent_config=None,
                pending_writes=None,
            )
        return None

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """Async version of put."""
        self.put(config, checkpoint, metadata, new_versions)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Async version of put_writes. No-op - writes are handled through aput()."""
        pass