197 lines
7.1 KiB
Python
197 lines
7.1 KiB
Python
from typing import AsyncIterator
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import Request
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
from sqlalchemy import String, Uuid, DateTime, Boolean, UniqueConstraint
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model defined in this module."""
|
|
|
|
|
|
# Tools Table
|
|
class Tool(Base):
    """Available chatbot tools.

    Stores information about all available tools that can be assigned to users.
    Each tool has a unique ``tool_id`` that corresponds to the registry tool_id.
    """

    __tablename__ = "tools"

    # Surrogate primary key; the default value is generated in Python by uuid4.
    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Registry identifier of the tool; unique across the table.
    tool_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # label and description are not overwritten for existing rows by
    # sync_tools_from_registry (it treats them as user-editable fields).
    label: Mapped[str] = mapped_column(String(255), nullable=False)
    category: Mapped[str] = mapped_column(String(50), nullable=False)
    description: Mapped[str] = mapped_column(String(1000), nullable=False)
    # Set to False by the registry sync when the tool is no longer registered.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # ``onupdate`` refreshes this column whenever the row is UPDATEd without
    # an explicit value for it, so it tracks the last modification instead of
    # staying frozen at the insert time.
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
|
|
|
|
|
|
# User-Tool Mapping Table
|
|
class UserToolMapping(Base):
    """Mapping of users to their available tools.

    Many-to-many relationship between users and tools.
    - One user can have multiple tools
    - One tool can be assigned to multiple users

    The combination of ``user_id`` and ``tool_id`` is unique.
    """

    __tablename__ = "user_tools"
    __table_args__ = (UniqueConstraint("user_id", "tool_id", name="uq_user_tool"),)

    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Opaque user identifier; no foreign key is declared for it.
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # NOTE(review): presumably references Tool.id (same Uuid type, unlike the
    # string Tool.tool_id), but no ForeignKey is declared, so the database
    # does not enforce that the referenced tool exists.
    tool_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    date_granted: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # ``onupdate`` refreshes this column whenever the row is UPDATEd without
    # an explicit value for it (e.g. when is_active is toggled).
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
|
|
|
|
|
|
# User Thread Mapping Table
|
|
class UserThreadMapping(Base):
    """Mapping of users to their chat threads.

    Used to keep track of which user owns which chat thread.
    Also stores metadata such as the thread name.

    1:N relationship between user and thread. A thread belongs to exactly one
    user; a user can have multiple threads. ``thread_id`` is unique in the table.
    """

    __tablename__ = "user_threads"

    id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4)
    # Opaque user identifier; no foreign key is declared for it.
    user_id: Mapped[str] = mapped_column(String(255), nullable=False)
    # Unique constraint enforces "a thread belongs to exactly one user".
    thread_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    thread_name: Mapped[str] = mapped_column(String(255), nullable=False)
    date_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )
    # ``onupdate`` refreshes this column whenever the row is UPDATEd without
    # an explicit value for it (e.g. when the thread is renamed).
    date_updated: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
|
|
|
|
|
|
# Dependency that pulls the sessionmaker off app.state
|
|
# It is set on app.state in app.py at startup, inside the @asynccontextmanager function
|
|
# TODO: If we use SQLAlchemy in other places, we can move this to a shared module
|
|
async def get_async_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """Yield one ``AsyncSession`` for the duration of a request.

    The session factory is read from
    ``request.app.state.checkpoint_sessionmaker``; the session is opened with
    ``async with``, so it is closed once the generator is resumed after the
    yield (or closed by the framework).
    """
    app_state = request.app.state
    session_factory: async_sessionmaker[AsyncSession] = app_state.checkpoint_sessionmaker
    async with session_factory() as db_session:
        yield db_session
|
|
|
|
|
|
# Optional helper to init tables at startup (demo only)
|
|
async def init_models(engine) -> None:
    """Create the tables registered on ``Base.metadata`` (demo helper).

    Opens a transaction with ``engine.begin()`` and runs the synchronous
    ``create_all`` through ``run_sync`` on the async connection.
    ``engine`` is assumed to be an SQLAlchemy AsyncEngine — TODO confirm.
    NOTE(review): with ``create_all``'s default checkfirst behaviour, existing
    tables are presumably skipped rather than altered — confirm before relying
    on this for schema changes.
    """
    async with engine.begin() as connection:
        create_all = Base.metadata.create_all
        await connection.run_sync(create_all)
|
|
|
|
|
|
async def sync_tools_from_registry(session: AsyncSession) -> None:
    """Sync tools from the tool registry into the ``tools`` table.

    This function:
    - Adds tools that are in the registry but not yet in the database
      (``label`` is initialised to the tool_id, per spec).
    - Refreshes ``name`` and ``category`` of existing tools from the registry
      and re-activates them. ``label`` and ``description`` are preserved
      because they are user-editable.
    - Marks tools that are no longer in the registry as inactive.

    ``date_updated`` is bumped only for rows that actually change, so it keeps
    reflecting the last real modification rather than the last application
    start. The "updated" count in the summary log likewise counts only rows
    that changed.

    Changes are only staged on ``session``; this function neither flushes nor
    commits, so the caller is responsible for committing the transaction.

    Should be called on application startup after database initialization.

    Args:
        session: Active database session.
    """
    import logging

    from sqlalchemy import select

    from modules.features.chatBot.utils.toolRegistry import get_registry

    logger = logging.getLogger(__name__)
    logger.info("Syncing tools from registry to database...")

    # Get all tools from the registry.
    registry_tools = get_registry().get_all_tools()
    registry_tool_ids = {tool.tool_id for tool in registry_tools}
    logger.info("Found %d tools in registry", len(registry_tools))

    # Get all existing tools from the database.
    result = await session.execute(select(Tool))
    db_tools = result.scalars().all()
    db_tools_by_tool_id = {tool.tool_id: tool for tool in db_tools}
    logger.info("Found %d tools in database", len(db_tools))

    # One timestamp for the whole run so every row touched by this sync agrees.
    now = datetime.now(timezone.utc)
    added_count = 0
    updated_count = 0
    deactivated_count = 0

    # Sync tools from registry to database.
    for registry_tool in registry_tools:
        db_tool = db_tools_by_tool_id.get(registry_tool.tool_id)

        if db_tool is None:
            new_tool = Tool(
                tool_id=registry_tool.tool_id,
                name=registry_tool.name,
                label=registry_tool.tool_id,  # Use tool_id as label per spec
                category=registry_tool.category,
                description=registry_tool.description or "",
                is_active=True,
            )
            session.add(new_tool)
            # Track the pending row so a duplicate tool_id in the registry is
            # handled as an update instead of a second INSERT that would
            # violate the unique constraint on tools.tool_id.
            db_tools_by_tool_id[registry_tool.tool_id] = new_tool
            added_count += 1
            logger.debug("Added new tool: %s", registry_tool.tool_id)
            continue

        # Preserve label and description (user-editable fields); only the
        # registry-owned fields and the active flag are compared.
        needs_update = (
            db_tool.name != registry_tool.name
            or db_tool.category != registry_tool.category
            or not db_tool.is_active
        )
        if not needs_update:
            continue

        db_tool.name = registry_tool.name
        db_tool.category = registry_tool.category
        db_tool.is_active = True
        db_tool.date_updated = now
        updated_count += 1
        logger.debug("Updated tool: %s", registry_tool.tool_id)

    # Mark tools not in registry as inactive. Only rows loaded from the
    # database are considered; rows added above are always in the registry.
    for db_tool in db_tools:
        if db_tool.is_active and db_tool.tool_id not in registry_tool_ids:
            db_tool.is_active = False
            db_tool.date_updated = now
            deactivated_count += 1
            logger.debug("Deactivated tool not in registry: %s", db_tool.tool_id)

    logger.info(
        "Tool sync complete: %d added, %d updated, %d deactivated",
        added_count,
        updated_count,
        deactivated_count,
    )
|