gateway/modules/features/chatbotV2/bridges/memory.py

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chatbot V2 checkpointer - maps LangGraph state to ChatbotV2 message storage.
"""
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
# Use LangGraph's canonical CheckpointTuple instead of re-declaring an
# equivalent NamedTuple locally.
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
)

from modules.features.chatbotV2.interfaceFeatureChatbotV2 import getInterface as getChatbotV2Interface
from modules.features.chatbotV2.datamodelFeatureChatbotV2 import ChatbotV2Message
from modules.datamodels.datamodelUam import User
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)


def _sanitize_llm_response(text: str) -> str:
    """Strip chat template tokens and trailing junk from a raw model response."""
    if not isinstance(text, str) or not text:
        return ""
    # Cut the response off before the earliest chat-template sentinel; anything
    # after it is template scaffolding, not assistant content.
    for sentinel in ("<|im_start|>", "<|im_end|>", "<|endoftext|>", "<|user|>", "<|assistant|>"):
        if sentinel in text:
            text = text.split(sentinel)[0]
    return text.strip()
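# Illustrative examples:
#   _sanitize_llm_response("Sure, done.<|im_end|>\n<|im_start|>user")  -> "Sure, done."
#   _sanitize_llm_response("")                                         -> ""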


class ChatbotV2Checkpointer(BaseCheckpointSaver):
    """Checkpointer that stores messages via the ChatbotV2 interface.

    ``workflow_id`` doubles as the ChatbotV2 conversation id and serves as the
    fallback thread_id when the run config does not carry one.
    """

    def __init__(
        self,
        user: User,
        workflow_id: str,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None
    ):
        super().__init__()  # let the base saver set up its defaults (serializer)
        self.user = user
        self.workflow_id = workflow_id
        self.interface = getChatbotV2Interface(
            user,
            mandateId=mandateId,
            featureInstanceId=featureInstanceId
        )

    def _to_db_message(
        self,
        msg: BaseMessage,
        sequence_nr: int,
        round_number: int
    ) -> Dict[str, Any]:
        """Map a LangChain message onto the ChatbotV2 message schema."""
        content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
        if isinstance(msg, AIMessage):
            role = "assistant"
            content = _sanitize_llm_response(content)
        elif isinstance(msg, SystemMessage):
            role = "system"
        else:
            role = "user"
        return {
            "id": str(uuid.uuid4()),
            "conversationId": self.workflow_id,
            "message": content,
            "role": role,
            "status": "step" if sequence_nr > 1 else "first",
            "sequenceNr": sequence_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": round_number
        }
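    # Shape of a stored record (illustrative values):
    #   {"id": "<uuid4>", "conversationId": "<workflow_id>", "message": "Hi there",
    #    "role": "assistant", "status": "step", "sequenceNr": 2,
    #    "publishedAt": <utc timestamp>, "roundNumber": 1}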

    def _to_langchain(self, messages: List[ChatbotV2Message]) -> List[BaseMessage]:
        """Map stored ChatbotV2 messages back to LangChain message objects."""
        result: List[BaseMessage] = []
        for m in messages:
            if m.role == "user":
                result.append(HumanMessage(content=m.message or ""))
            elif m.role == "assistant":
                result.append(AIMessage(content=m.message or ""))
            elif m.role == "system":
                result.append(SystemMessage(content=m.message or ""))
        return result
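    # Messages with roles other than user/assistant/system are dropped from the
    # reconstructed history rather than guessed at.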

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> Dict[str, Any]:
        """Persist any new messages from the checkpoint to ChatbotV2 storage.

        Returns the config unchanged; ``BaseCheckpointSaver.put`` is expected
        to return the config the checkpoint was stored under.
        """
        thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
        conv = self.interface.getConversation(thread_id)
        if not conv:
            logger.warning("Conversation %s not found", thread_id)
            return config
        round_number = conv.currentRound or 1
        state = checkpoint.get("channel_values", {})
        messages = state.get("messages", [])
        if not messages:
            return config
        # LangGraph re-sends the full message history with every checkpoint, so
        # deduplicate on (role, content) and persist only genuinely new messages.
        existing = self.interface.getMessages(thread_id)
        existing_set = {(m.role, m.message) for m in existing}
        existing_count = len(existing)
        for msg in messages:
            if not isinstance(msg, (HumanMessage, AIMessage)):
                continue
            role = "user" if isinstance(msg, HumanMessage) else "assistant"
            content = msg.content if isinstance(msg.content, str) else str(msg.content or "")
            if isinstance(msg, AIMessage):
                content = _sanitize_llm_response(content)
            if not content or not content.strip():
                continue
            if (role, content) in existing_set:
                continue
            existing_set.add((role, content))
            existing_count += 1
            db_msg = self._to_db_message(msg, existing_count, round_number)
            self.interface.createMessage(db_msg)
        self.interface.updateConversation(thread_id, {"lastActivity": getUtcTimestamp()})
        return config

    def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]:
        """Rebuild a checkpoint from the stored conversation history."""
        thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
        conv = self.interface.getConversation(thread_id)
        if not conv:
            return None
        messages = self.interface.getMessages(thread_id)
        lc_messages = self._to_langchain(messages)
        # Synthesize a minimal checkpoint around the persisted messages; channel
        # versions are not tracked, so LangGraph always sees a fresh snapshot.
        return {
            "id": str(uuid.uuid4()),
            "v": 1,
            "ts": getUtcTimestamp(),
            "channel_values": {"messages": lc_messages},
            "channel_versions": {},
            "versions_seen": {},
            # pending_sends is part of the Checkpoint schema in recent LangGraph
            # releases; included defensively (assumption about the installed version).
            "pending_sends": []
        }

    # Async methods required for LangGraph ainvoke/astream.
    async def aget_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """Async version of get, wrapped in the CheckpointTuple LangGraph expects."""
        checkpoint = self.get(config)
        if not checkpoint:
            return None
        metadata: CheckpointMetadata = {"step": 0}
        return CheckpointTuple(
            config=config,
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=None,
            pending_writes=None,
        )

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> Dict[str, Any]:
        """Async version of put; delegates to the synchronous implementation."""
        return self.put(config, checkpoint, metadata, new_versions)

async def aput_writes(
self,
config: Dict[str, Any],
writes: List[Tuple[str, Any]],
task_id: str,
) -> None:
"""Async version of put_writes. No-op - writes are handled through aput()."""
pass
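
# Usage sketch (hypothetical wiring; `build_graph`, `user`, and `conversation`
# are placeholders, not part of this module):
#
#   checkpointer = ChatbotV2Checkpointer(user, workflow_id=conversation.id)
#   graph = build_graph().compile(checkpointer=checkpointer)
#   result = await graph.ainvoke(
#       {"messages": [HumanMessage(content="Hello")]},
#       config={"configurable": {"thread_id": conversation.id}},
#   )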