# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""
Custom LangGraph checkpointer using existing database interface.
Maps LangGraph state to existing message storage format.
"""

import contextvars
import logging
import uuid
from typing import Any, Dict, List, Optional, Tuple, NamedTuple
from datetime import datetime

from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata


# CheckpointTuple might not be directly importable, so we define it as a NamedTuple
# Based on LangGraph's usage, it needs config, checkpoint, metadata, parent_config, and pending_writes
class CheckpointTuple(NamedTuple):
    """Tuple containing config, checkpoint, metadata, parent_config, and pending_writes."""
    config: Dict[str, Any]
    checkpoint: Checkpoint
    metadata: CheckpointMetadata
    parent_config: Optional[Dict[str, Any]] = None
    pending_writes: Optional[List[Tuple[str, Any]]] = None


from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage

from modules.features.chatbot.interfaceFeatureChatbot import getInterface as getChatbotInterface
from modules.features.chatbot.interfaceFeatureChatbot import ChatbotMessage
from modules.datamodels.datamodelUam import User
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)


def _sanitize_llm_response(text: str) -> str:
    """Strip chat template tokens and trailing junk that some models leak."""
    if not text or not isinstance(text, str):
        return text or ""
    # Truncate at the first leaked chat-template sentinel, if any appears.
    for sentinel in ("<|im_start|>", "<|im_end|>", "<|endoftext|>", "<|user|>", "<|assistant|>"):
        if sentinel in text:
            text = text.split(sentinel)[0]
    return text.strip()


class DatabaseCheckpointer(BaseCheckpointSaver):
    """
    Custom LangGraph checkpointer that uses the chatbot's own database interface.
    Maps LangGraph thread_id to conversation.id; stores messages via interface
    (workflowId maps to conversationId).
    """

    def __init__(
        self,
        user: User,
        workflow_id: str,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None,
        *,
        interface=None,
    ):
        """
        Initialize the database checkpointer.

        Args:
            user: Current user for database access
            workflow_id: Workflow ID (maps to LangGraph thread_id)
            mandateId: Mandate ID for proper data isolation
            featureInstanceId: Feature instance ID for proper data isolation
            interface: Optional pre-created chatbot interface (avoids extra
                getInterface + DB init)
        """
        self.user = user
        self.workflow_id = workflow_id
        # Reuse a caller-supplied interface when given; otherwise create one.
        self.interface = interface if interface is not None else getChatbotInterface(
            user, mandateId=mandateId, featureInstanceId=featureInstanceId
        )

    def _convert_langchain_to_db_message(
        self, msg: BaseMessage, sequence_nr: int, round_number: int
    ) -> Dict[str, Any]:
        """
        Convert LangChain message to database message format.

        Args:
            msg: LangChain message
            sequence_nr: Sequence number for ordering
            round_number: Round number in workflow

        Returns:
            Dictionary in database message format
        """
        role = "user"
        content = ""

        if isinstance(msg, HumanMessage):
            role = "user"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, AIMessage):
            role = "assistant"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, SystemMessage):
            # System messages are stored but marked as system
            role = "system"
            content = msg.content if isinstance(msg.content, str) else str(msg.content)
        elif isinstance(msg, ToolMessage):
            # Tool messages are stored as assistant messages with tool info
            role = "assistant"
            content = f"Tool {msg.name}: {msg.content}"

        return {
            "id": f"msg_{uuid.uuid4()}",
            "workflowId": self.workflow_id,
            "message": content,
            "role": role,
            # The very first message in a workflow gets status "first".
            "status": "step" if sequence_nr > 1 else "first",
            "sequenceNr": sequence_nr,
            "publishedAt": getUtcTimestamp(),
            "roundNumber": round_number,
            "taskNumber": 0,
            "actionNumber": 0
        }

    def _convert_db_to_langchain_messages(
        self, messages: List[ChatbotMessage]
    ) -> List[BaseMessage]:
        """
        Convert database messages to LangChain messages.

        Args:
            messages: List of database ChatMessage objects

        Returns:
            List of LangChain BaseMessage objects
        """
        langchain_messages: List[BaseMessage] = []
        for msg in messages:
            if msg.role == "user":
                langchain_messages.append(HumanMessage(content=msg.message))
            elif msg.role == "assistant":
                langchain_messages.append(AIMessage(content=msg.message))
            elif msg.role == "system":
                langchain_messages.append(SystemMessage(content=msg.message))
            # Skip other roles for now
        return langchain_messages

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """
        Store a checkpoint in the database.

        Persists only NEW user/assistant messages from the checkpoint state,
        deduplicated against what is already stored, and skips intermediate
        tool-call / raw-SQL assistant messages.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)

            # Get current workflow to determine round number
            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.warning(f"Workflow {thread_id} not found, cannot store checkpoint")
                return
            round_number = workflow.currentRound

            # Extract messages from checkpoint
            state = checkpoint.get("channel_values", {})
            messages = state.get("messages", [])
            if not messages:
                logger.debug(f"No messages in checkpoint for workflow {thread_id}")
                return

            # Get existing messages to determine what's already stored
            existing_messages = self.interface.getMessages(thread_id)
            existing_count = len(existing_messages) if existing_messages else 0

            # Create a set of existing message content+role for quick lookup
            existing_content_set = set()
            if existing_messages:
                for existing_msg in existing_messages:
                    # Create a unique key from role and message content
                    content_key = (existing_msg.role, existing_msg.message)
                    existing_content_set.add(content_key)

            # Filter checkpoint messages to only user/assistant (skip system)
            # Skip intermediate AIMessages with tool_calls (these are tool call
            # requests, not final answers)
            checkpoint_user_assistant_messages = []
            for msg in messages:
                if isinstance(msg, HumanMessage):
                    # Always store user messages
                    checkpoint_user_assistant_messages.append(msg)
                elif isinstance(msg, AIMessage):
                    # Check if this message has tool_calls
                    tool_calls = getattr(msg, "tool_calls", None)
                    if tool_calls and len(tool_calls) > 0:
                        logger.debug(f"Skipping intermediate AIMessage with tool_calls for workflow {thread_id}")
                        continue
                    # Skip agent_sql_plan output (raw SQL block) - only store
                    # agent_formulate final answer
                    content = msg.content if isinstance(msg.content, str) else str(msg.content)
                    cu = (content or "").strip().upper()
                    if content and (
                        content.strip().startswith("```")
                        or (cu.startswith("SELECT") and ("FROM" in cu or "JOIN" in cu))
                    ):
                        logger.debug(f"Skipping intermediate SQL AIMessage for workflow {thread_id}")
                        continue
                    checkpoint_user_assistant_messages.append(msg)

            # Only store new messages that aren't already in the database
            new_messages_to_store = []
            for msg in checkpoint_user_assistant_messages:
                role = "user" if isinstance(msg, HumanMessage) else "assistant"
                content = msg.content if isinstance(msg.content, str) else str(msg.content)
                if isinstance(msg, AIMessage):
                    content = _sanitize_llm_response(content)
                if not content or not content.strip():
                    continue
                content_key = (role, content)
                if content_key not in existing_content_set:
                    # Re-wrap when sanitization changed the content so the
                    # stored message matches what dedup keys on.
                    if isinstance(msg, AIMessage) and msg.content != content:
                        msg = AIMessage(content=content)
                    new_messages_to_store.append(msg)
                    existing_content_set.add(content_key)

            # Store only the new messages
            if new_messages_to_store:
                for msg in new_messages_to_store:
                    # BUG FIX: previously sequence_nr = existing_count + i
                    # (enumerate from 1) while ALSO incrementing existing_count
                    # per success, which produced gapped sequence numbers
                    # (N+1, N+3, N+5, ...). Derive it from existing_count only.
                    sequence_nr = existing_count + 1

                    # Convert to database format
                    db_message_data = self._convert_langchain_to_db_message(
                        msg, sequence_nr, round_number
                    )

                    # Store the message
                    try:
                        self.interface.createMessage(db_message_data)
                        logger.debug(f"Stored message {db_message_data['id']} for workflow {thread_id}")
                        existing_count += 1  # Update count for next message
                    except Exception as e:
                        logger.error(f"Error storing message: {e}", exc_info=True)
            else:
                logger.debug(f"No new messages to store for workflow {thread_id} (existing: {existing_count}, checkpoint: {len(checkpoint_user_assistant_messages)})")

            # Update workflow last activity
            self.interface.updateWorkflow(thread_id, {
                "lastActivity": getUtcTimestamp()
            })

        except Exception as e:
            logger.error(f"Error storing checkpoint: {e}", exc_info=True)
            raise

    def get(
        self,
        config: Dict[str, Any],
    ) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint from the database.

        The checkpoint is reconstructed from stored messages; a fresh UUID is
        generated as the checkpoint id on every call.

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            Checkpoint if found, None otherwise
        """
        try:
            # Extract thread_id from config (maps to workflow_id)
            thread_id = config.get("configurable", {}).get("thread_id", self.workflow_id)

            # Get workflow
            workflow = self.interface.getWorkflow(thread_id)
            if not workflow:
                logger.debug(f"Workflow {thread_id} not found")
                return None

            # Get messages
            messages = self.interface.getMessages(thread_id)
            checkpoint_id = str(uuid.uuid4())
            if not messages:
                # Return empty checkpoint for new workflow
                return {
                    "id": checkpoint_id,
                    "v": 1,
                    "ts": getUtcTimestamp(),
                    "channel_values": {
                        "messages": []
                    },
                    "channel_versions": {},
                    "versions_seen": {}
                }

            # Convert to LangChain messages
            langchain_messages = self._convert_db_to_langchain_messages(messages)

            # Build checkpoint
            checkpoint = {
                "id": checkpoint_id,
                "v": 1,
                "ts": getUtcTimestamp(),
                "channel_values": {
                    "messages": langchain_messages
                },
                "channel_versions": {},
                "versions_seen": {}
            }
            return checkpoint

        except Exception as e:
            logger.error(f"Error retrieving checkpoint: {e}", exc_info=True)
            return None

    def list(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        List checkpoints (not fully implemented - returns current checkpoint).

        Args:
            config: LangGraph config
            filter: Optional filter
            before: Optional timestamp before which to list
            limit: Optional limit on number of results

        Returns:
            List of checkpoints
        """
        # filter/before/limit are intentionally ignored: only one checkpoint
        # (the current DB-derived one) exists per thread.
        checkpoint = self.get(config)
        if checkpoint:
            return [checkpoint]
        return []

    def put_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Store checkpoint writes (not used in current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - using put() instead
        pass

    async def aget_tuple(
        self,
        config: Dict[str, Any],
    ) -> Optional[CheckpointTuple]:
        """
        Async version of get that returns tuple of (config, checkpoint, metadata).

        Args:
            config: LangGraph config (contains thread_id)

        Returns:
            CheckpointTuple with config, checkpoint and metadata if found,
            None otherwise
        """
        checkpoint = self.get(config)
        if checkpoint:
            # Return checkpoint with metadata including step
            # CheckpointMetadata is typically a TypedDict
            # LangGraph expects 'step' in metadata
            metadata: CheckpointMetadata = {
                "step": 0  # Start at step 0, LangGraph will increment
            }
            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                parent_config=None,  # No parent checkpoint for our implementation
                pending_writes=None  # No pending writes in our implementation
            )
        return None

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        """
        Async version of put.

        Args:
            config: LangGraph config (contains thread_id)
            checkpoint: Checkpoint to store
            metadata: Checkpoint metadata
            new_versions: New version numbers
        """
        self.put(config, checkpoint, metadata, new_versions)

    async def alist(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        """
        Async version of list.

        Args:
            config: LangGraph config
            filter: Optional filter
            before: Optional timestamp before which to list
            limit: Optional limit on number of results

        Returns:
            List of checkpoints
        """
        return self.list(config, filter, before, limit)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        """
        Async version of put_writes. Store checkpoint writes (not used in
        current implementation).

        Args:
            config: LangGraph config
            writes: List of write operations
            task_id: Task ID
        """
        # Not implemented - using aput() instead
        # This method is called by LangGraph but we handle writes through aput()
        pass


# ContextVar for per-request checkpointer (used by CheckpointerResolver for graph caching)
_current_checkpointer: contextvars.ContextVar[Optional[BaseCheckpointSaver]] = contextvars.ContextVar(
    "chatbot_current_checkpointer", default=None
)


def set_checkpointer(checkpointer: BaseCheckpointSaver) -> contextvars.Token:
    """Set the current request's checkpointer. Returns token to reset later."""
    return _current_checkpointer.set(checkpointer)


def reset_checkpointer(token: contextvars.Token) -> None:
    """Reset checkpointer to prior value. Safe when called from a different async context."""
    try:
        _current_checkpointer.reset(token)
    except ValueError:
        # Token was created in a different context (e.g. after yield, generator cleanup)
        pass


class CheckpointerResolver(BaseCheckpointSaver):
    """
    Delegating checkpointer that reads the real checkpointer from context.

    Used for graph caching: the compiled graph uses this resolver; at invoke
    time the per-request checkpointer is set via set_checkpointer().
    """

    def _get_checkpointer(self) -> BaseCheckpointSaver:
        """Return the per-request checkpointer; raise if none is set."""
        cp = _current_checkpointer.get()
        if cp is None:
            raise RuntimeError(
                "CheckpointerResolver: no checkpointer in context. "
                "Call set_checkpointer() before invoking the cached graph."
            )
        return cp

    def put(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        self._get_checkpointer().put(config, checkpoint, metadata, new_versions)

    def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]:
        return self._get_checkpointer().get(config)

    def list(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        return self._get_checkpointer().list(config, filter, before, limit)

    def put_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        self._get_checkpointer().put_writes(config, writes, task_id)

    async def aget_tuple(self, config: Dict[str, Any]) -> Optional[CheckpointTuple]:
        inner = self._get_checkpointer()
        if hasattr(inner, "aget_tuple"):
            return await inner.aget_tuple(config)
        # Sync fallback: build the tuple from the sync get().
        checkpoint = inner.get(config)
        if checkpoint:
            metadata: CheckpointMetadata = {"step": 0}
            return CheckpointTuple(
                config=config,
                checkpoint=checkpoint,
                metadata=metadata,
                parent_config=None,
                pending_writes=None,
            )
        return None

    async def aput(
        self,
        config: Dict[str, Any],
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: Dict[str, int],
    ) -> None:
        inner = self._get_checkpointer()
        if hasattr(inner, "aput"):
            await inner.aput(config, checkpoint, metadata, new_versions)
        else:
            inner.put(config, checkpoint, metadata, new_versions)

    async def alist(
        self,
        config: Dict[str, Any],
        filter: Optional[Dict[str, Any]] = None,
        before: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Checkpoint]:
        inner = self._get_checkpointer()
        if hasattr(inner, "alist"):
            return await inner.alist(config, filter, before, limit)
        return inner.list(config, filter, before, limit)

    async def aput_writes(
        self,
        config: Dict[str, Any],
        writes: List[Tuple[str, Any]],
        task_id: str,
    ) -> None:
        inner = self._get_checkpointer()
        if hasattr(inner, "aput_writes"):
            await inner.aput_writes(config, writes, task_id)
        else:
            # CONSISTENCY FIX: previously writes were silently dropped when the
            # inner checkpointer had no aput_writes; fall back to the sync API
            # like aput/alist/aget_tuple do.
            inner.put_writes(config, writes, task_id)