diff --git a/app.py b/app.py index b75b59e8..384d6bfa 100644 --- a/app.py +++ b/app.py @@ -275,6 +275,14 @@ async def lifespan(app: FastAPI): # NOTE: Might need Alembic migrations in the future await init_chatbot_models(engine) + # --- Sync tools from registry to database --- + from modules.features.chatBot.database import sync_tools_from_registry + + async with SessionLocal() as session: + await sync_tools_from_registry(session) + await session.commit() + logger.info("Tools synced from registry to database") + # --- Initialize LangGraph checkpointer --- from modules.features.chatBot.utils.checkpointer import ( diff --git a/modules/features/chatBot/database.py b/modules/features/chatBot/database.py index 5473fa0f..1dc4ebe6 100644 --- a/modules/features/chatBot/database.py +++ b/modules/features/chatBot/database.py @@ -5,13 +5,71 @@ from datetime import datetime, timezone from fastapi import Request from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import String, Uuid, DateTime +from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint class Base(DeclarativeBase): pass +# Tools Table +class Tool(Base): + """Available chatbot tools. + + Stores information about all available tools that can be assigned to users. + Each tool has a unique tool_id that corresponds to the registry tool_id. + """ + + __tablename__ = "tools" + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + label: Mapped[str] = mapped_column(String(255), nullable=False) + category: Mapped[str] = mapped_column(String(50), nullable=False) + description: Mapped[str] = mapped_column(String(1000), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + date_created: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + date_updated: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + + +# User-Tool Mapping Table +class UserToolMapping(Base): + """Mapping of users to their available tools. + + Many-to-many relationship between users and tools. + - One user can have multiple tools + - One tool can be assigned to multiple users + + The combination of user_id and tool_id is unique. + """ + + __tablename__ = "user_tools" + __table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),) + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + user_id: Mapped[str] = mapped_column(String(255), nullable=False) + tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + date_granted: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + date_updated: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + + # User Thread Mapping Table class UserThreadMapping(Base): """Mapping of users to their chat threads. @@ -56,3 +114,84 @@ async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]: async def init_models(engine) -> None: async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) + + +async def sync_tools_from_registry(session: AsyncSession) -> None: + """Sync tools from tool registry to database. + + This function: + - Adds new tools from the registry to the database + - Updates existing tools with current registry information + - Marks tools not present in the registry as inactive + + Should be called on application startup after database initialization. + + Args: + session: Active database session + """ + import logging + from sqlalchemy import select + + from modules.features.chatBot.utils.toolRegistry import get_registry + + logger = logging.getLogger(__name__) + logger.info("Syncing tools from registry to database...") + + # Get all tools from the registry + registry = get_registry() + registry_tools = registry.get_all_tools() + + # Create a set of tool_ids from the registry + registry_tool_ids = {tool.tool_id for tool in registry_tools} + + logger.info(f"Found {len(registry_tools)} tools in registry") + + # Get all existing tools from the database + result = await session.execute(select(Tool)) + db_tools = result.scalars().all() + db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools} + + logger.info(f"Found {len(db_tools)} tools in database") + + # Track changes + added_count = 0 + updated_count = 0 + deactivated_count = 0 + + # Sync tools from registry to database + for registry_tool in registry_tools: + if registry_tool.tool_id in db_tools_by_tool_id: + # Tool exists - update it + # Preserve label and description (user-editable fields) + db_tool = db_tools_by_tool_id[registry_tool.tool_id] + db_tool.name = registry_tool.name + db_tool.category = registry_tool.category + db_tool.is_active = True + db_tool.date_updated = datetime.now(timezone.utc) + updated_count += 1 + logger.debug(f"Updated tool: {registry_tool.tool_id}") + else: + # Tool doesn't exist - create it + new_tool = Tool( + tool_id=registry_tool.tool_id, + name=registry_tool.name, + label=registry_tool.tool_id, # Use tool_id as label per spec + category=registry_tool.category, + description=registry_tool.description or "", + is_active=True, + ) + session.add(new_tool) + added_count += 1 + logger.debug(f"Added new tool: {registry_tool.tool_id}") + + # Mark tools not in registry as inactive + for db_tool in db_tools: + if db_tool.tool_id not in registry_tool_ids and db_tool.is_active: + db_tool.is_active = False + db_tool.date_updated = datetime.now(timezone.utc) + deactivated_count += 1 + logger.debug(f"Deactivated tool not in registry: {db_tool.tool_id}") + + logger.info( + f"Tool sync complete: {added_count} added, {updated_count} updated, {deactivated_count} deactivated" + )