# gateway/modules/features/chatbot/bridges/memory.py

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Custom LangGraph checkpointer using existing database interface.
Maps LangGraph state to existing message storage format.
"""
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
from datetime import datetime
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata
# CheckpointTuple might not be directly importable, so we define it as a NamedTuple
# Based on LangGraph's usage, it needs config, checkpoint, metadata, parent_config, and pending_writes
class CheckpointTuple(NamedTuple):
    """
    Local stand-in for LangGraph's CheckpointTuple (see comment above).

    Field layout mirrors what LangGraph expects back from
    get_tuple()/aget_tuple(): config, checkpoint, metadata, plus optional
    parent linkage and pending writes.
    """
    # RunnableConfig-style dict; carries "configurable" -> "thread_id".
    config: Dict[str, Any]
    # The checkpoint payload itself.
    checkpoint: Checkpoint
    # Metadata dict (e.g. {"step": ...}) consumed by LangGraph.
    metadata: CheckpointMetadata
    # Config of the parent checkpoint when lineage is tracked (unused here).
    parent_config: Optional[Dict[str, Any]] = None
    # Buffered (channel, value) writes not yet folded in (unused here).
    pending_writes: Optional[List[Tuple[str, Any]]] = None
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from modules.interfaces.interfaceDbChat import getInterface
from modules.datamodels.datamodelChat import ChatMessage, ChatWorkflow
from modules.datamodels.datamodelUam import User
from modules.shared.timeUtils import getUtcTimestamp
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
class DatabaseCheckpointer(BaseCheckpointSaver):
    """
    Custom LangGraph checkpointer that uses the existing database interface.

    Maps the LangGraph thread_id to workflow.id and stores messages in the
    existing chat-message format instead of a dedicated checkpoint table.
    """

    def __init__(self, user: User, workflow_id: str):
        """
        Initialize the database checkpointer.

        Args:
            user: Current user for database access
            workflow_id: Workflow ID (maps to LangGraph thread_id)
        """
        # BUGFIX: BaseCheckpointSaver.__init__ was never invoked, leaving
        # base-class state (e.g. the serializer attribute) uninitialized.
        super().__init__()
        self.user = user
        self.workflow_id = workflow_id
        # Resolve the per-user database interface once; reused by all methods.
        self.interface = getInterface(user)
def _convert_langchain_to_db_message(
self,
msg: BaseMessage,
sequence_nr: int,
round_number: int
) -> Dict[str, Any]:
"""
Convert LangChain message to database message format.
Args:
msg: LangChain message
sequence_nr: Sequence number for ordering
round_number: Round number in workflow
Returns:
Dictionary in database message format
"""
import uuid
role = "user"
content = ""
if isinstance(msg, HumanMessage):
role = "user"
content = msg.content if isinstance(msg.content, str) else str(msg.content)
elif isinstance(msg, AIMessage):
role = "assistant"
content = msg.content if isinstance(msg.content, str) else str(msg.content)
elif isinstance(msg, SystemMessage):
# System messages are stored but marked as system
role = "system"
content = msg.content if isinstance(msg.content, str) else str(msg.content)
elif isinstance(msg, ToolMessage):
# Tool messages are stored as assistant messages with tool info
role = "assistant"
content = f"Tool {msg.name}: {msg.content}"
return {
"id": f"msg_{uuid.uuid4()}",
"workflowId": self.workflow_id,
"message": content,
"role": role,
"status": "step" if sequence_nr > 1 else "first",
"sequenceNr": sequence_nr,
"publishedAt": getUtcTimestamp(),
"roundNumber": round_number,
"taskNumber": 0,
"actionNumber": 0
}
def _convert_db_to_langchain_messages(
self,
messages: List[ChatMessage]
) -> List[BaseMessage]:
"""
Convert database messages to LangChain messages.
Args:
messages: List of database ChatMessage objects
Returns:
List of LangChain BaseMessage objects
"""
langchain_messages = []
for msg in messages:
if msg.role == "user":
langchain_messages.append(HumanMessage(content=msg.message))
elif msg.role == "assistant":
langchain_messages.append(AIMessage(content=msg.message))
elif msg.role == "system":
langchain_messages.append(SystemMessage(content=msg.message))
# Skip other roles for now
return langchain_messages
def put(
self,
config: Dict[str, Any],
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: Dict[str, int],
) -> None:
"""
Store a checkpoint in the database.
Args:
config: LangGraph config (contains thread_id)
checkpoint: Checkpoint to store
metadata: Checkpoint metadata
new_versions: New version numbers
"""
try:
# Extract thread_id from config (maps to workflow_id)
thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
# Get current workflow to determine round number
workflow = self.interface.getWorkflow(thread_id)
if not workflow:
logger.warning(f"Workflow {thread_id} not found, cannot store checkpoint")
return
round_number = workflow.currentRound if workflow else 1
# Extract messages from checkpoint
state = checkpoint.get("channel_values", {})
messages = state.get("messages", [])
if not messages:
logger.debug(f"No messages in checkpoint for workflow {thread_id}")
return
# Get existing messages to determine what's already stored
existing_messages = self.interface.getMessages(thread_id)
existing_count = len(existing_messages) if existing_messages else 0
# Create a set of existing message content+role for quick lookup
existing_content_set = set()
if existing_messages:
for existing_msg in existing_messages:
# Create a unique key from role and message content
content_key = (existing_msg.role, existing_msg.message)
existing_content_set.add(content_key)
# Filter checkpoint messages to only user/assistant (skip system)
# Skip intermediate AIMessages with tool_calls (these are tool call requests, not final answers)
checkpoint_user_assistant_messages = []
for msg in messages:
if isinstance(msg, HumanMessage):
# Always store user messages
checkpoint_user_assistant_messages.append(msg)
elif isinstance(msg, AIMessage):
# Check if this message has tool_calls
tool_calls = getattr(msg, "tool_calls", None)
# Skip messages with tool_calls - these are intermediate tool call requests
if tool_calls and len(tool_calls) > 0:
logger.debug(f"Skipping intermediate AIMessage with tool_calls for workflow {thread_id}")
continue
# Store all other AIMessages (final answers)
checkpoint_user_assistant_messages.append(msg)
# Only store new messages that aren't already in the database
new_messages_to_store = []
for msg in checkpoint_user_assistant_messages:
# Determine role
role = "user" if isinstance(msg, HumanMessage) else "assistant"
content = msg.content if isinstance(msg.content, str) else str(msg.content)
# Skip empty messages (they might be status updates)
if not content or not content.strip():
continue
# Check if this message already exists
content_key = (role, content)
if content_key not in existing_content_set:
new_messages_to_store.append(msg)
existing_content_set.add(content_key) # Mark as seen to avoid duplicates in this batch
# Store only the new messages
if new_messages_to_store:
for i, msg in enumerate(new_messages_to_store, 1):
sequence_nr = existing_count + i
# Convert to database format
db_message_data = self._convert_langchain_to_db_message(
msg,
sequence_nr,
round_number
)
# Store the message
try:
self.interface.createMessage(db_message_data)
logger.debug(f"Stored message {db_message_data['id']} for workflow {thread_id}")
existing_count += 1 # Update count for next message
except Exception as e:
logger.error(f"Error storing message: {e}", exc_info=True)
else:
logger.debug(f"No new messages to store for workflow {thread_id} (existing: {existing_count}, checkpoint: {len(checkpoint_user_assistant_messages)})")
# Update workflow last activity
self.interface.updateWorkflow(thread_id, {
"lastActivity": getUtcTimestamp()
})
except Exception as e:
logger.error(f"Error storing checkpoint: {e}", exc_info=True)
raise
def get(
self,
config: Dict[str, Any],
) -> Optional[Checkpoint]:
"""
Retrieve a checkpoint from the database.
Args:
config: LangGraph config (contains thread_id)
Returns:
Checkpoint if found, None otherwise
"""
try:
# Extract thread_id from config (maps to workflow_id)
thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)
# Get workflow
workflow = self.interface.getWorkflow(thread_id)
if not workflow:
logger.debug(f"Workflow {thread_id} not found")
return None
# Get messages
messages = self.interface.getMessages(thread_id)
checkpoint_id = str(uuid.uuid4())
if not messages:
# Return empty checkpoint for new workflow
return {
"id": checkpoint_id,
"v": 1,
"ts": getUtcTimestamp(),
"channel_values": {
"messages": []
},
"channel_versions": {},
"versions_seen": {}
}
# Convert to LangChain messages
langchain_messages = self._convert_db_to_langchain_messages(messages)
# Build checkpoint
checkpoint = {
"id": checkpoint_id,
"v": 1,
"ts": getUtcTimestamp(),
"channel_values": {
"messages": langchain_messages
},
"channel_versions": {},
"versions_seen": {}
}
return checkpoint
except Exception as e:
logger.error(f"Error retrieving checkpoint: {e}", exc_info=True)
return None
def list(
self,
config: Dict[str, Any],
filter: Optional[Dict[str, Any]] = None,
before: Optional[str] = None,
limit: Optional[int] = None,
) -> List[Checkpoint]:
"""
List checkpoints (not fully implemented - returns current checkpoint).
Args:
config: LangGraph config
filter: Optional filter
before: Optional timestamp before which to list
limit: Optional limit on number of results
Returns:
List of checkpoints
"""
checkpoint = self.get(config)
if checkpoint:
return [checkpoint]
return []
def put_writes(
self,
config: Dict[str, Any],
writes: List[Tuple[str, Any]],
task_id: str,
) -> None:
"""
Store checkpoint writes (not used in current implementation).
Args:
config: LangGraph config
writes: List of write operations
task_id: Task ID
"""
# Not implemented - using put() instead
pass
async def aget_tuple(
self,
config: Dict[str, Any],
) -> Optional[CheckpointTuple]:
"""
Async version of get that returns tuple of (config, checkpoint, metadata).
Args:
config: LangGraph config (contains thread_id)
Returns:
CheckpointTuple with config, checkpoint and metadata if found, None otherwise
"""
checkpoint = self.get(config)
if checkpoint:
# Return checkpoint with metadata including step
# CheckpointMetadata is typically a TypedDict
# LangGraph expects 'step' in metadata
metadata: CheckpointMetadata = {
"step": 0 # Start at step 0, LangGraph will increment
}
return CheckpointTuple(
config=config,
checkpoint=checkpoint,
metadata=metadata,
parent_config=None, # No parent checkpoint for our implementation
pending_writes=None # No pending writes in our implementation
)
return None
async def aput(
self,
config: Dict[str, Any],
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: Dict[str, int],
) -> None:
"""
Async version of put.
Args:
config: LangGraph config (contains thread_id)
checkpoint: Checkpoint to store
metadata: Checkpoint metadata
new_versions: New version numbers
"""
self.put(config, checkpoint, metadata, new_versions)
async def alist(
self,
config: Dict[str, Any],
filter: Optional[Dict[str, Any]] = None,
before: Optional[str] = None,
limit: Optional[int] = None,
) -> List[Checkpoint]:
"""
Async version of list.
Args:
config: LangGraph config
filter: Optional filter
before: Optional timestamp before which to list
limit: Optional limit on number of results
Returns:
List of checkpoints
"""
return self.list(config, filter, before, limit)
async def aput_writes(
self,
config: Dict[str, Any],
writes: List[Tuple[str, Any]],
task_id: str,
) -> None:
"""
Async version of put_writes.
Store checkpoint writes (not used in current implementation).
Args:
config: LangGraph config
writes: List of write operations
task_id: Task ID
"""
# Not implemented - using aput() instead
# This method is called by LangGraph but we handle writes through aput()
pass