106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
"""PostgreSQL checkpointer utilities for LangGraph memory."""
|
|
|
|
import sys
|
|
import asyncio
|
|
import logging
|
|
from typing import Optional
|
|
|
|
# Fix for Windows asyncio compatibility with psycopg (backup in case app.py fix didn't apply).
# psycopg's async mode cannot run on Windows' default ProactorEventLoop, so a
# selector-based policy is required there. Install it only when it is not already
# the active policy: set_event_loop_policy() never reports "already set" -- it
# silently replaces the current policy object, discarding any loop that policy
# already tracks (e.g. one created after app.py installed the same fix).
if sys.platform == "win32" and not isinstance(
    asyncio.get_event_loop_policy(), asyncio.WindowsSelectorEventLoopPolicy
):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
|
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
from psycopg_pool import AsyncConnectionPool
|
|
from psycopg.rows import dict_row
|
|
from modules.shared.configuration import APP_CONFIG
|
|
|
|
logger = logging.getLogger(__name__)

# Module-level singletons: both are assigned by initialize_checkpointer(),
# reset to None by close_checkpointer(), and the saver is read through
# get_checkpointer().
_connection_pool: Optional[AsyncConnectionPool] = None
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
|
|
|
|
|
|
async def initialize_checkpointer() -> None:
    """Initialize the PostgreSQL checkpointer for LangGraph.

    This should be called during application startup. It creates an async
    connection pool, wraps it in an ``AsyncPostgresSaver`` and runs
    ``setup()`` on it (which creates the checkpoint tables if needed).

    A second call after a successful initialization is a no-op. If any step
    fails, the partially-created pool is closed and the module globals are
    left untouched, so a later call can retry from scratch instead of being
    short-circuited by a half-initialized saver.

    Raises:
        Exception: Whatever building the conninfo, opening the pool or
            ``setup()`` raised; it is logged with its traceback and re-raised.
    """
    global _checkpointer_instance, _connection_pool

    if _checkpointer_instance is not None:
        logger.info("Checkpointer already initialized")
        return

    # Local import: only this function needs it, and it ships with psycopg,
    # which this module already depends on (psycopg.rows).
    from psycopg.conninfo import make_conninfo

    pool: Optional[AsyncConnectionPool] = None
    try:
        # Get database configuration from environment and build a
        # keyword/value conninfo. make_conninfo escapes each value, so
        # passwords containing '@', ':', '/', '%' or spaces are safe (an
        # f-string URI would break on them). A None value is dropped, so a
        # missing password lets libpq fall back to PGPASSWORD/.pgpass instead
        # of sending the literal string "None".
        connection_string = make_conninfo(
            host=APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_HOST", "localhost"),
            port=APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PORT", "5432"),
            dbname=APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_DATABASE", "poweron_chat"),
            user=APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_USER", "poweron_dev"),
            password=APP_CONFIG.get("LANGGRAPH_CHECKPOINT_DB_PASSWORD_SECRET"),
        )

        # open=False: opening an async pool implicitly in its constructor is
        # deprecated by psycopg_pool; it is opened explicitly (awaited) below.
        pool = AsyncConnectionPool(
            conninfo=connection_string,
            min_size=2,
            max_size=10,
            open=False,
            # Connection settings the LangGraph Postgres saver expects:
            # autocommit connections returning dict rows.
            kwargs={"autocommit": True, "row_factory": dict_row},
        )
        await pool.open()

        checkpointer = AsyncPostgresSaver(pool)
        # Setup the checkpointer (creates tables if needed).
        await checkpointer.setup()
    except Exception as e:
        logger.exception("Failed to initialize PostgreSQL checkpointer: %s", e)
        if pool is not None:
            # Don't leak the pool; never let a cleanup error mask the cause.
            try:
                await pool.close()
            except Exception:
                logger.exception(
                    "Error closing checkpointer connection pool after failed initialization"
                )
        raise

    # Publish the globals only once every step has succeeded.
    _connection_pool = pool
    _checkpointer_instance = checkpointer
    logger.info("PostgreSQL checkpointer initialized successfully")
|
|
|
|
|
|
async def close_checkpointer() -> None:
    """Close the checkpointer and its connection pool.

    This should be called during application shutdown. It is safe to call
    when nothing was initialized.

    The module globals are cleared *before* the pool is closed. As a result,
    get_checkpointer() stops returning a saver backed by a closing pool, and
    a second (possibly concurrent) call finds nothing left to close. Errors
    raised while closing the pool are logged with their traceback and not
    re-raised, so the rest of shutdown can proceed.
    """
    global _checkpointer_instance, _connection_pool

    pool = _connection_pool
    _checkpointer_instance = None
    _connection_pool = None

    if pool is None:
        return

    try:
        await pool.close()
    except Exception as e:
        logger.exception("Error closing checkpointer connection pool: %s", e)
    else:
        logger.info("PostgreSQL checkpointer connection pool closed")
|
|
|
|
|
|
def get_checkpointer() -> AsyncPostgresSaver:
    """Return the module-level AsyncPostgresSaver.

    Returns:
        The saver stored by initialize_checkpointer().

    Raises:
        RuntimeError: If no saver is currently stored, i.e.
            initialize_checkpointer() has not run yet or close_checkpointer()
            has since reset it.
    """
    saver = _checkpointer_instance
    if saver is not None:
        return saver
    raise RuntimeError(
        "PostgreSQL checkpointer not initialized. "
        "Call initialize_checkpointer() during application startup."
    )
|