# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chatbot V2 checkpointer - maps LangGraph state to ChatbotV2 message storage.
"""

import logging
import uuid
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata


class CheckpointTuple(NamedTuple):
    """Tuple containing config, checkpoint, metadata, parent_config, and pending_writes."""

    config: Dict[str, Any]
    checkpoint: Checkpoint
    metadata: CheckpointMetadata
    parent_config: Optional[Dict[str, Any]] = None
    pending_writes: Optional[List[Tuple[str, Any]]] = None


# NOTE(review): these project imports sat after the NamedTuple in the original
# file, presumably to avoid an import cycle — kept in place; confirm before
# moving them to the top of the module.
from modules.features.chatbotV2.interfaceFeatureChatbotV2 import getInterface as getChatbotV2Interface
from modules.features.chatbotV2.datamodelFeatureChatbotV2 import ChatbotV2Message
from modules.datamodels.datamodelUam import User
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)

# Chat-template tokens that must never leak into stored assistant messages.
# Hoisted to module level so the tuple is not rebuilt on every call.
_TEMPLATE_SENTINELS = (
    "<|im_start|>",
    "<|im_end|>",
    "<|endoftext|>",
    "<|user|>",
    "<|assistant|>",
)


def _sanitize_llm_response(text: str) -> str:
    """Strip chat template tokens and trailing junk.

    Truncates ``text`` before each template sentinel it contains and strips
    surrounding whitespace. Falsy or non-string input yields ``""``.
    """
    if not text or not isinstance(text, str):
        return text or ""
    for sentinel in _TEMPLATE_SENTINELS:
        if sentinel in text:
            # Keep only the part before the first occurrence of the sentinel.
            text = text.split(sentinel)[0]
    return text.strip()


class ChatbotV2Checkpointer(BaseCheckpointSaver):
    """Checkpointer that stores messages via ChatbotV2 interface."""

    def __init__(
        self,
        user: User,
        workflow_id: str,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ):
        """Bind the checkpointer to a user and a conversation/workflow id.

        Args:
            user: Authenticated user whose ChatbotV2 interface is used.
            workflow_id: Default conversation id (fallback thread_id).
            mandateId: Optional mandate scope forwarded to the interface.
            featureInstanceId: Optional feature-instance scope.
        """
        # Fix: the original never called super().__init__(), leaving
        # BaseCheckpointSaver's own attributes uninitialized.
        super().__init__()
        self.user = user
        self.workflow_id = workflow_id
        self.interface = getChatbotV2Interface(
            user, mandateId=mandateId, featureInstanceId=featureInstanceId
        )

    @staticmethod
    def _message_content(msg: BaseMessage) -> str:
        """Normalize a message's content to ``str``.

        LangChain message ``content`` may be a plain string or a structured
        list of parts; non-string content is stringified, ``None`` becomes "".
        (Hoisted: this expression was duplicated four times in the original.)
        """
        return msg.content if isinstance(msg.content, str) else str(msg.content or "")

    def _to_db_message(
        self, msg: BaseMessage, sequence_nr: int, round_number: int
    ) -> Dict[str, Any]:
        """Map a LangChain message to a ChatbotV2 message record dict.

        Args:
            msg: Human/AI/System message to persist.
            sequence_nr: 1-based position within the conversation.
            round_number: Current conversation round.

        Returns:
            Dict matching the ChatbotV2 message storage schema. Unknown
            message types fall through as role "user" with empty content.
        """
        role = "user"
        content = ""
        if isinstance(msg, HumanMessage):
            role = "user"
            content = self._message_content(msg)
        elif isinstance(msg, AIMessage):
            role = "assistant"
            # Assistant output may still carry chat-template tokens.
            content = _sanitize_llm_response(self._message_content(msg))
        elif isinstance(msg, SystemMessage):
            role = "system"
            content = self._message_content(msg)
        return {
            "id": str(uuid.uuid4()),
            "conversationId": self.workflow_id,
            "message": content,
            "role": role,
            # Only the very first message of a conversation is "first".
            "status": "step" if sequence_nr > 1 else "first",
            "sequenceNr": sequence_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": round_number
        }

    def _to_langchain(self, messages: List[ChatbotV2Message]) -> List[BaseMessage]:
        """Convert stored ChatbotV2 messages back into LangChain messages.

        Messages with an unrecognized role are silently dropped.
        """
        result: List[BaseMessage] = []
        for m in messages:
            if m.role == "user":
                result.append(HumanMessage(content=m.message or ""))
            elif m.role == "assistant":
                result.append(AIMessage(content=m.message or ""))
            elif m.role == "system":
                result.append(SystemMessage(content=m.message or ""))
        return result

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """Persist new human/assistant messages from ``checkpoint``.

        Deduplicates against already-stored messages by ``(role, content)``;
        system messages, empty contents and duplicates are skipped. Also
        bumps the conversation's ``lastActivity`` timestamp.
        """
        thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
        conv = self.interface.getConversation(thread_id)
        if not conv:
            # Lazy %-formatting: avoid building the string when the level is off.
            logger.warning("Conversation %s not found", thread_id)
            return
        round_number = conv.currentRound or 1

        state = checkpoint.get("channel_values", {})
        messages = state.get("messages", [])
        if not messages:
            return

        existing = self.interface.getMessages(thread_id)
        existing_set = {(m.role, m.message) for m in existing}
        existing_count = len(existing)

        for msg in messages:
            if not isinstance(msg, (HumanMessage, AIMessage)):
                continue
            role = "user" if isinstance(msg, HumanMessage) else "assistant"
            content = self._message_content(msg)
            if isinstance(msg, AIMessage):
                # Sanitize here too so dedup compares the stored form.
                content = _sanitize_llm_response(content)
            if not content or not content.strip():
                continue
            if (role, content) in existing_set:
                continue
            existing_set.add((role, content))
            existing_count += 1
            db_msg = self._to_db_message(msg, existing_count, round_number)
            self.interface.createMessage(db_msg)

        self.interface.updateConversation(thread_id, {"lastActivity": getUtcTimestamp()})

    def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]:
        """Rebuild a checkpoint from the stored conversation messages.

        Returns ``None`` when the conversation does not exist. The returned
        checkpoint carries a fresh id and only the ``messages`` channel.
        """
        thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
        conv = self.interface.getConversation(thread_id)
        if not conv:
            return None
        messages = self.interface.getMessages(thread_id)
        lc_messages = self._to_langchain(messages)
        return {
            "id": str(uuid.uuid4()),
            "v": 1,
            "ts": getUtcTimestamp(),
            "channel_values": {"messages": lc_messages},
            "channel_versions": {},
            "versions_seen": {}
        }

    # Async methods required for LangGraph ainvoke/astream

    async def aget_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """Async version of get that returns tuple of (config, checkpoint, metadata)."""
        checkpoint = self.get(config)
        if checkpoint:
            metadata: CheckpointMetadata = {"step": 0}
            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                parent_config=None,
                pending_writes=None,
            )
        return None

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """Async version of put."""
        self.put(config, checkpoint, metadata, new_versions)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """Async version of put_writes. No-op - writes are handled through aput()."""
        pass