# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Custom LangGraph checkpointer using existing database interface.
Maps LangGraph state to existing message storage format.
"""
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
from datetime import datetime

from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata


# CheckpointTuple might not be directly importable, so we define it as a NamedTuple
# Based on LangGraph's usage, it needs config, checkpoint, metadata, parent_config, and pending_writes
class CheckpointTuple(NamedTuple):
    """Tuple containing config, checkpoint, metadata, parent_config, and pending_writes."""
    config: Dict[str, Any]
    checkpoint: Checkpoint
    metadata: CheckpointMetadata
    parent_config: Optional[Dict[str, Any]] = None
    pending_writes: Optional[List[Tuple[str, Any]]] = None


from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage

from modules.interfaces.interfaceDbChat import getInterface
from modules.datamodels.datamodelChat import ChatMessage, ChatWorkflow
from modules.datamodels.datamodelUam import User
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)


class DatabaseCheckpointer(BaseCheckpointSaver):
    """
    Custom LangGraph checkpointer that uses the existing database interface.

    Maps LangGraph thread_id to workflow.id and stores messages in the
    existing format. Only user messages and final assistant answers are
    persisted; system messages and intermediate tool-call requests are
    filtered out on write (see put()).
    """

    def __init__(self, user: User, workflow_id: str):
        """
        Initialize the database checkpointer.

        Args:
            user: Current user for database access
            workflow_id: Workflow ID (maps to LangGraph thread_id)
        """
        # Fix: initialize the base saver so its serde machinery is set up.
        super().__init__()
        self.user = user
        self.workflow_id = workflow_id
        self.interface = getInterface(user)

    def _convert_langchain_to_db_message(
        self,
        msg: BaseMessage,
        sequence_nr: int,
        round_number: int
    ) -> Dict[str, Any]:
        """
        Convert LangChain message to database message format.

        Args:
            msg: LangChain message
            sequence_nr: Sequence number for ordering
            round_number: Round number in workflow

        Returns:
            Dictionary in database message format
        """
        # Note: uuid is imported at module level; the previous redundant
        # function-local import was removed.
        role = "user"
        content = ""

        if isinstance(msg, HumanMessage):
            role = "user"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, AIMessage):
            role = "assistant"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, SystemMessage):
            # System messages are stored but marked as system
            role = "system"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, ToolMessage):
            # Tool messages are stored as assistant messages with tool info
            role = "assistant"
            content = f"Tool {msg.name}: {msg.content}"

        return {
            "id": f"msg_{uuid.uuid4()}",
            "workflowId": self.workflow_id,
            "message": content,
            "role": role,
            # The very first message in a workflow is flagged "first";
            # every later message is a "step".
            "status": "step" if sequence_nr > 1 else "first",
            "sequenceNr": sequence_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": round_number,
            "taskNumber": 0,
            "actionNumber": 0
        }

    def _convert_db_to_langchain_messages(
        self,
        messages: List[ChatMessage]
    ) -> List[BaseMessage]:
        """
        Convert database messages to LangChain messages.

        Args:
            messages: List of database ChatMessage objects

        Returns:
            List of LangChain BaseMessage objects
        """
        langchain_messages: List[BaseMessage] = []

        for msg in messages:
            if msg.role == "user":
                langchain_messages.append(HumanMessage(content=msg.message))
            elif msg.role == "assistant":
                langchain_messages.append(AIMessage(content=msg.message))
            elif msg.role == "system":
                langchain_messages.append(SystemMessage(content=msg.message))
            # Skip other roles for now

        return langchain_messages

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> Optional[Dict[str, Any]]:
        """
        Store a checkpoint in the database.

        Only messages not already present in the database are written;
        system messages and AIMessages that carry tool_calls (intermediate
        tool-call requests) are skipped.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers

        Returns:
            The config used for storage (per BaseCheckpointSaver contract),
            or None when storage was skipped (missing workflow / no messages).

        Raises:
            Exception: Re-raises any unexpected failure after logging it.
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)

            # Get current workflow to determine round number
            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.warning(f"Workflow {thread_id} not found, cannot store checkpoint")
                return None

            # Fix: the "if workflow else 1" fallback was dead code — we
            # already returned above when workflow is falsy.
            round_number = workflow.currentRound

            # Extract messages from checkpoint
            state = checkpoint.get("channel_values", {})
            messages = state.get("messages", [])

            if not messages:
                logger.debug(f"No messages in checkpoint for workflow {thread_id}")
                return None

            # Get existing messages to determine what's already stored
            existing_messages = self.interface.getMessages(thread_id)
            existing_count = len(existing_messages) if existing_messages else 0

            # Create a set of existing message content+role for quick lookup
            existing_content_set = set()
            if existing_messages:
                for existing_msg in existing_messages:
                    # Create a unique key from role and message content
                    content_key = (existing_msg.role, existing_msg.message)
                    existing_content_set.add(content_key)

            # Filter checkpoint messages to only user/assistant (skip system)
            # Skip intermediate AIMessages with tool_calls (these are tool
            # call requests, not final answers)
            checkpoint_user_assistant_messages = []
            for msg in messages:
                if isinstance(msg, HumanMessage):
                    # Always store user messages
                    checkpoint_user_assistant_messages.append(msg)
                elif isinstance(msg, AIMessage):
                    # Check if this message has tool_calls
                    tool_calls = getattr(msg, "tool_calls", None)
                    # Skip messages with tool_calls - these are intermediate tool call requests
                    if tool_calls and len(tool_calls) > 0:
                        logger.debug(f"Skipping intermediate AIMessage with tool_calls for workflow {thread_id}")
                        continue
                    # Store all other AIMessages (final answers)
                    checkpoint_user_assistant_messages.append(msg)

            # Only store new messages that aren't already in the database
            new_messages_to_store = []
            for msg in checkpoint_user_assistant_messages:
                # Determine role
                role = "user" if isinstance(msg, HumanMessage) else "assistant"
                content = msg.content if isinstance(msg.content, str) else str(msg.content)

                # Skip empty messages (they might be status updates)
                if not content or not content.strip():
                    continue

                # Check if this message already exists
                content_key = (role, content)
                if content_key not in existing_content_set:
                    new_messages_to_store.append(msg)
                    existing_content_set.add(content_key)  # Mark as seen to avoid duplicates in this batch

            # Store only the new messages
            if new_messages_to_store:
                for i, msg in enumerate(new_messages_to_store, 1):
                    sequence_nr = existing_count + i

                    # Convert to database format
                    db_message_data = self._convert_langchain_to_db_message(
                        msg, sequence_nr, round_number
                    )

                    # Store the message; a single failed insert should not
                    # abort the remaining messages of the batch.
                    try:
                        self.interface.createMessage(db_message_data)
                        logger.debug(f"Stored message {db_message_data['id']} for workflow {thread_id}")
                        existing_count += 1  # Update count for next message
                    except Exception as e:
                        logger.error(f"Error storing message: {e}", exc_info=True)
            else:
                logger.debug(f"No new messages to store for workflow {thread_id} (existing: {existing_count}, checkpoint: {len(checkpoint_user_assistant_messages)})")

            # Update workflow last activity
            self.interface.updateWorkflow(thread_id, {
                "lastActivity": getUtcTimestamp()
            })

            # Fix: BaseCheckpointSaver.put is expected to return the config
            # used for storage; previously this method returned None.
            return config

        except Exception as e:
            logger.error(f"Error storing checkpoint: {e}", exc_info=True)
            raise

    def get(
        self,
        config: Dict[str, Any],
    ) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint from the database.

        Rebuilds a checkpoint dict from the workflow's stored messages;
        a workflow without messages yields an empty checkpoint.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            Checkpoint if found, None otherwise (including on any error,
            which is logged rather than raised)
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)

            # Get workflow
            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.debug(f"Workflow {thread_id} not found")
                return None

            # Get messages
            messages = self.interface.getMessages(thread_id)

            # A fresh id is generated per retrieval — checkpoints are
            # reconstructed from messages, not stored verbatim.
            checkpoint_id = str(uuid.uuid4())

            if not messages:
                # Return empty checkpoint for new workflow
                return {
                    "id": checkpoint_id,
                    "v": 1,
                    "ts": getUtcTimestamp(),
                    "channel_values": {
                        "messages": []
                    },
                    "channel_versions": {},
                    "versions_seen": {}
                }

            # Convert to LangChain messages
            langchain_messages = self._convert_db_to_langchain_messages(messages)

            # Build checkpoint
            checkpoint = {
                "id": checkpoint_id,
                "v": 1,
                "ts": getUtcTimestamp(),
                "channel_values": {
                    "messages": langchain_messages
                },
                "channel_versions": {},
                "versions_seen": {}
            }

            return checkpoint

        except Exception as e:
            logger.error(f"Error retrieving checkpoint: {e}", exc_info=True)
            return None

    def list(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        List checkpoints (not fully implemented - returns current checkpoint).

        Args:
            config: LangGraph config
            filter: Optional filter
            before: Optional timestamp before which to list
            limit: Optional limit on number of results

        Returns:
            List of checkpoints (at most one — the current checkpoint)
        """
        checkpoint = self.get(config)
        if checkpoint:
            return [checkpoint]
        return []

    def put_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Store checkpoint writes (not used in current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - using put() instead
        pass

    def get_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """
        Return the current checkpoint as a CheckpointTuple.

        Added so synchronous LangGraph invocation (which calls get_tuple)
        works; previously only the async aget_tuple existed.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            CheckpointTuple with config, checkpoint and metadata if found,
            None otherwise
        """
        checkpoint = self.get(config)
        if checkpoint:
            # Return checkpoint with metadata including step
            # CheckpointMetadata is typically a TypedDict
            # LangGraph expects 'step' in metadata
            metadata: CheckpointMetadata = {
                "step": 0  # Start at step 0, LangGraph will increment
            }
            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                parent_config=None,  # No parent checkpoint for our implementation
                pending_writes=None  # No pending writes in our implementation
            )
        return None

    async def aget_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """
        Async version of get_tuple.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            CheckpointTuple with config, checkpoint and metadata if found,
            None otherwise
        """
        # The underlying database interface is synchronous, so simply
        # delegate to the sync implementation.
        return self.get_tuple(config)

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> Optional[Dict[str, Any]]:
        """
        Async version of put.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers

        Returns:
            The config used for storage, or None when storage was skipped.
        """
        return self.put(config, checkpoint, metadata, new_versions)

    async def alist(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        Async version of list.

        Args:
            config: LangGraph config
            filter: Optional filter
            before: Optional timestamp before which to list
            limit: Optional limit on number of results

        Returns:
            List of checkpoints
        """
        return self.list(config, filter, before, limit)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Async version of put_writes.

        Store checkpoint writes (not used in current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - using aput() instead
        # This method is called by LangGraph but we handle writes through aput()
        pass