432 lines
16 KiB
Python
432 lines
16 KiB
Python
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
|
|
"""
|
|
Custom LangGraph checkpointer using existing database interface.
|
|
Maps LangGraph state to existing message storage format.
|
|
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
|
|
from datetime import datetime
|
|
|
|
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata
|
|
|
|
# CheckpointTuple might not be directly importable, so we define it as a NamedTuple.
# Based on LangGraph's usage, it needs config, checkpoint, metadata, parent_config, and pending_writes.
|
|
class CheckpointTuple(NamedTuple):
    """Tuple containing config, checkpoint, metadata, parent_config, and pending_writes."""

    # LangGraph config the checkpoint was looked up with.
    config: Dict[str, Any]
    # The checkpoint payload itself.
    checkpoint: Checkpoint
    # Metadata accompanying the checkpoint.
    metadata: CheckpointMetadata
    # Config of the parent checkpoint; defaults to None.
    parent_config: Optional[Dict[str, Any]] = None
    # Pending (channel, value) writes; defaults to None.
    pending_writes: Optional[List[Tuple[str, Any]]] = None
|
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
|
|
|
|
from modules.interfaces.interfaceDbChat import getInterface
|
|
from modules.datamodels.datamodelChat import ChatMessage, ChatWorkflow
|
|
from modules.datamodels.datamodelUam import User
|
|
from modules.shared.timeUtils import getUtcTimestamp
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatabaseCheckpointer(BaseCheckpointSaver):
    """
    Custom LangGraph checkpointer that uses the existing database interface.

    Maps LangGraph thread_id to workflow.id and stores messages in the existing format.
    Only conversation messages are persisted: get() always returns empty
    channel_versions / versions_seen, and put_writes() is a no-op.
    """

    def __init__(self, user: User, workflow_id: str):
        """
        Initialize the database checkpointer.

        Args:
            user: Current user for database access
            workflow_id: Workflow ID (maps to LangGraph thread_id)
        """
        super().__init__()
        self.user = user
        self.workflow_id = workflow_id
        self.interface = getInterface(user)

    @staticmethod
    def _message_text(msg: BaseMessage) -> str:
        """Return msg.content as a string; non-string content is passed through str()."""
        raw = msg.content
        return raw if isinstance(raw, str) else str(raw)

    def _resolve_thread_id(self, config: Dict[str, Any]) -> str:
        """Return config["configurable"]["thread_id"], falling back to self.workflow_id."""
        return config.get("configurable", {}).get("thread_id", self.workflow_id)

    def _convert_langchain_to_db_message(
        self,
        msg: BaseMessage,
        sequence_nr: int,
        round_number: int
    ) -> Dict[str, Any]:
        """
        Convert LangChain message to database message format.

        Args:
            msg: LangChain message
            sequence_nr: Sequence number for ordering
            round_number: Round number in workflow

        Returns:
            Dictionary in database message format
        """
        # Defaults apply to message types that match none of the branches below.
        role = "user"
        content = ""

        if isinstance(msg, HumanMessage):
            role = "user"
            content = self._message_text(msg)
        elif isinstance(msg, AIMessage):
            role = "assistant"
            content = self._message_text(msg)
        elif isinstance(msg, SystemMessage):
            # System messages are stored but marked as system
            role = "system"
            content = self._message_text(msg)
        elif isinstance(msg, ToolMessage):
            # Tool messages are stored as assistant messages with tool info
            role = "assistant"
            content = f"Tool {msg.name}: {msg.content}"

        return {
            "id": f"msg_{uuid.uuid4()}",
            "workflowId": self.workflow_id,
            "message": content,
            "role": role,
            # The message with sequence number 1 (or lower) is flagged "first".
            "status": "step" if sequence_nr > 1 else "first",
            "sequenceNr": sequence_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": round_number,
            "taskNumber": 0,
            "actionNumber": 0,
        }

    def _convert_db_to_langchain_messages(
        self,
        messages: List[ChatMessage]
    ) -> List[BaseMessage]:
        """
        Convert database messages to LangChain messages.

        Args:
            messages: List of database ChatMessage objects

        Returns:
            List of LangChain BaseMessage objects; messages whose role is not
            "user", "assistant" or "system" are skipped.
        """
        langchain_messages = []
        for msg in messages:
            if msg.role == "user":
                langchain_messages.append(HumanMessage(content=msg.message))
            elif msg.role == "assistant":
                langchain_messages.append(AIMessage(content=msg.message))
            elif msg.role == "system":
                langchain_messages.append(SystemMessage(content=msg.message))
            # Skip other roles for now
        return langchain_messages

    def _select_storable_messages(
        self,
        messages: List[BaseMessage],
        thread_id: str
    ) -> List[BaseMessage]:
        """Keep user messages and AI messages without tool_calls; drop everything else."""
        selected = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                # Always store user messages
                selected.append(msg)
            elif isinstance(msg, AIMessage):
                # AI messages carrying tool_calls are intermediate tool call
                # requests, not final answers.
                if getattr(msg, "tool_calls", None):
                    logger.debug(f"Skipping intermediate AIMessage with tool_calls for workflow {thread_id}")
                    continue
                selected.append(msg)
        return selected

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """
        Store a checkpoint in the database.

        Only user messages and final assistant messages that are not yet stored
        for the workflow are written; metadata and new_versions are not persisted.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers

        Raises:
            Exception: Any error outside the per-message store call is logged
                and re-raised. A failure to store a single message is logged
                and the remaining messages are still attempted.
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = self._resolve_thread_id(config)

            # Get current workflow to determine round number
            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.warning(f"Workflow {thread_id} not found, cannot store checkpoint")
                return
            round_number = workflow.currentRound

            # Extract messages from checkpoint
            messages = checkpoint.get("channel_values", {}).get("messages", [])
            if not messages:
                logger.debug(f"No messages in checkpoint for workflow {thread_id}")
                return

            # Get existing messages to determine what's already stored
            existing_messages = self.interface.getMessages(thread_id) or []
            existing_count = len(existing_messages)

            # NOTE(review): de-duplication is keyed on (role, text), so a message
            # that repeats an earlier one verbatim is never stored a second time.
            seen_keys = {(stored.role, stored.message) for stored in existing_messages}

            candidates = self._select_storable_messages(messages, thread_id)

            # Only store new messages that aren't already in the database
            new_messages_to_store = []
            for msg in candidates:
                role = "user" if isinstance(msg, HumanMessage) else "assistant"
                text = self._message_text(msg)

                # Skip empty messages (they might be status updates)
                if not text.strip():
                    continue

                key = (role, text)
                if key in seen_keys:
                    continue
                seen_keys.add(key)  # Mark as seen to avoid duplicates in this batch
                new_messages_to_store.append(msg)

            if new_messages_to_store:
                # One running counter keeps sequence numbers contiguous; it only
                # advances after a message was actually stored.
                next_sequence_nr = existing_count + 1
                for msg in new_messages_to_store:
                    db_message_data = self._convert_langchain_to_db_message(
                        msg,
                        next_sequence_nr,
                        round_number
                    )
                    try:
                        self.interface.createMessage(db_message_data)
                    except Exception as e:
                        # Best effort: keep going with the remaining messages.
                        logger.error(f"Error storing message: {e}", exc_info=True)
                        continue
                    logger.debug(f"Stored message {db_message_data['id']} for workflow {thread_id}")
                    next_sequence_nr += 1
            else:
                logger.debug(f"No new messages to store for workflow {thread_id} (existing: {existing_count}, checkpoint: {len(candidates)})")

            # Update workflow last activity
            self.interface.updateWorkflow(thread_id, {
                "lastActivity": getUtcTimestamp()
            })

        except Exception as e:
            logger.error(f"Error storing checkpoint: {e}", exc_info=True)
            raise

    def get(
        self,
        config: Dict[str, Any],
    ) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint from the database.

        The checkpoint is rebuilt from the stored messages on every call, with a
        fresh id and timestamp and empty channel_versions / versions_seen.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            Checkpoint if the workflow exists (with an empty message list when
            nothing is stored yet); None if the workflow is missing or an error
            occurred (the error is logged, not raised).
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = self._resolve_thread_id(config)

            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.debug(f"Workflow {thread_id} not found")
                return None

            messages = self.interface.getMessages(thread_id)
            checkpoint_id = str(uuid.uuid4())

            # A workflow without stored messages yields an empty checkpoint.
            langchain_messages = (
                self._convert_db_to_langchain_messages(messages) if messages else []
            )

            return {
                "id": checkpoint_id,
                "v": 1,
                "ts": getUtcTimestamp(),
                "channel_values": {
                    "messages": langchain_messages
                },
                "channel_versions": {},
                "versions_seen": {},
            }

        except Exception as e:
            logger.error(f"Error retrieving checkpoint: {e}", exc_info=True)
            return None

    def list(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        List checkpoints (not fully implemented - returns current checkpoint).

        Args:
            config: LangGraph config
            filter: Optional filter (accepted but currently ignored)
            before: Optional timestamp before which to list (accepted but currently ignored)
            limit: Optional limit on number of results (accepted but currently ignored)

        Returns:
            List holding the current checkpoint, or an empty list if get() returned None
        """
        checkpoint = self.get(config)
        if checkpoint:
            return [checkpoint]
        return []

    def put_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Store checkpoint writes (not used in current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - using put() instead
        pass

    def get_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """
        Synchronous counterpart of aget_tuple.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            CheckpointTuple with config, checkpoint and metadata if found, None otherwise
        """
        checkpoint = self.get(config)
        if not checkpoint:
            return None

        # CheckpointMetadata is typically a TypedDict; LangGraph expects 'step' in metadata
        metadata: CheckpointMetadata = {
            "step": 0  # Start at step 0, LangGraph will increment
        }
        return CheckpointTuple(
            config=config,
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=None,  # No parent checkpoint for our implementation
            pending_writes=None  # No pending writes in our implementation
        )

    async def aget_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """
        Async version of get that returns tuple of (config, checkpoint, metadata).

        The synchronous implementation is called inline, without awaiting.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            CheckpointTuple with config, checkpoint and metadata if found, None otherwise
        """
        return self.get_tuple(config)

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """
        Async version of put; calls put() inline.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers
        """
        self.put(config, checkpoint, metadata, new_versions)

    async def alist(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        Async version of list; calls list() inline.

        Args:
            config: LangGraph config
            filter: Optional filter
            before: Optional timestamp before which to list
            limit: Optional limit on number of results

        Returns:
            List of checkpoints
        """
        return self.list(config, filter, before, limit)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Async version of put_writes.
        Store checkpoint writes (not used in current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - using aput() instead
        # This method is called by LangGraph but we handle writes through aput()
        pass
|