from typing import AsyncIterator
import uuid

from fastapi import Request
from sqlalchemy import String, Uuid
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Declarative base for the ORM models defined in this module."""


# User Thread Mapping Table
class UserThreadMapping(Base):
    """Mapping of users to their chat threads.

    Tracks which user owns which chat thread and also stores thread
    metadata such as the thread name.

    The relationship between user and thread is 1:N: a thread belongs to
    exactly one user, and a user can have multiple threads. ``threadId`` is
    unique in the table.
    """

    __tablename__ = "userThreads"

    # Surrogate primary key; a random UUID is generated if none is supplied.
    id: Mapped[uuid.UUID] = mapped_column(
        Uuid, primary_key=True, default=uuid.uuid4
    )
    # Owner of the thread. Not unique, since one user may own many threads.
    # NOTE(review): if threads are commonly looked up per user, consider
    # index=True here (requires a schema change on existing databases).
    userId: Mapped[str] = mapped_column(String(255), nullable=False)
    # Unique, so each thread appears in exactly one row / belongs to one user.
    threadId: Mapped[str] = mapped_column(
        String(255), unique=True, nullable=False
    )
    # Thread name (metadata stored alongside the mapping).
    threadName: Mapped[str] = mapped_column(String(255), nullable=False)


# Dependency that pulls the sessionmaker off app.state.
# This is set in app.py on startup in @asynccontextmanager.
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields one ``AsyncSession`` per request.

    The session factory is read from
    ``request.app.state.checkpoint_sessionmaker``, which is expected to be
    set during application startup.

    Nothing here commits: callers must commit their own changes explicitly.

    Args:
        request: The incoming request, used only to reach ``app.state``.

    Yields:
        An ``AsyncSession`` that is closed when the ``async with`` block
        exits after the ``yield`` resumes.
    """
    session_factory: async_sessionmaker[AsyncSession] = (
        request.app.state.checkpoint_sessionmaker
    )
    async with session_factory() as session:
        yield session


# Optional helper to init tables at startup (demo only)
async def init_models(engine: AsyncEngine) -> None:
    """Create every table registered on ``Base.metadata``.

    ``MetaData.create_all`` is a synchronous API, so it is executed on the
    async connection via ``run_sync``. ``engine.begin()`` wraps the DDL in a
    transaction that is committed when the block exits. This is meant for
    demos; it creates missing tables but does not alter existing ones.

    Args:
        engine: The async engine used to connect and emit the DDL.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)