fix: typo; async checkpointer postgres
This commit is contained in:
parent
33f8ff1b5e
commit
dd16efb860
2 changed files with 9 additions and 7 deletions
|
|
@ -44,7 +44,7 @@ def get_langchain_model(*, model_name: str) -> ChatAnthropic:
|
||||||
"""
|
"""
|
||||||
# Model name mapping
|
# Model name mapping
|
||||||
model_mapping = {
|
model_mapping = {
|
||||||
"claude_4_5": "claude-4-5-sonnet",
|
"claude_4_5": "claude-sonnet-4-5",
|
||||||
# Add more mappings as needed
|
# Add more mappings as needed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,15 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langgraph.checkpoint.postgres import PostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
from psycopg.rows import dict_row
|
||||||
from modules.shared.configuration import APP_CONFIG
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Global checkpointer instance
|
# Global checkpointer instance
|
||||||
_checkpointer_instance: Optional[PostgresSaver] = None
|
_checkpointer_instance: Optional[AsyncPostgresSaver] = None
|
||||||
_connection_pool: Optional[AsyncConnectionPool] = None
|
_connection_pool: Optional[AsyncConnectionPool] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -42,13 +43,14 @@ async def initialize_checkpointer() -> None:
|
||||||
conninfo=connection_string,
|
conninfo=connection_string,
|
||||||
min_size=2,
|
min_size=2,
|
||||||
max_size=10,
|
max_size=10,
|
||||||
|
kwargs={"autocommit": True, "row_factory": dict_row},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the connection pool
|
# Initialize the connection pool
|
||||||
await _connection_pool.open()
|
await _connection_pool.open()
|
||||||
|
|
||||||
# Create PostgresSaver with the pool
|
# Create AsyncPostgresSaver with the pool
|
||||||
_checkpointer_instance = PostgresSaver(_connection_pool)
|
_checkpointer_instance = AsyncPostgresSaver(_connection_pool)
|
||||||
|
|
||||||
# Setup the checkpointer (creates tables if needed)
|
# Setup the checkpointer (creates tables if needed)
|
||||||
await _checkpointer_instance.setup()
|
await _checkpointer_instance.setup()
|
||||||
|
|
@ -78,11 +80,11 @@ async def close_checkpointer() -> None:
|
||||||
_connection_pool = None
|
_connection_pool = None
|
||||||
|
|
||||||
|
|
||||||
def get_checkpointer() -> PostgresSaver:
|
def get_checkpointer() -> AsyncPostgresSaver:
|
||||||
"""Get the global PostgreSQL checkpointer instance.
|
"""Get the global PostgreSQL checkpointer instance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The initialized PostgresSaver instance
|
The initialized AsyncPostgresSaver instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If checkpointer is not initialized
|
RuntimeError: If checkpointer is not initialized
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue