feat: add db query router
This commit is contained in:
parent
ed8dd344d7
commit
9bc7cd774c
11 changed files with 449 additions and 27 deletions
1
src/dataprocessor/__init__.py
Normal file
1
src/dataprocessor/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# noqa
|
||||
1
src/dataprocessor/domain/__init__.py
Normal file
1
src/dataprocessor/domain/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# noqa
|
||||
|
|
@ -1,11 +1,9 @@
|
|||
from msal import ConfidentialClientApplication, SerializableTokenCache
|
||||
import anyio # comes with FastAPI via Starlette/AnyIO
|
||||
from dataclasses import dataclass
|
||||
from settings import settings
|
||||
from src.settings import settings
|
||||
import pandas as pd
|
||||
import httpx
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -76,24 +74,6 @@ class PowerBIReader:
|
|||
df.columns = [_strip_qual(c) for c in df.columns]
|
||||
return df
|
||||
|
||||
async def to_sqlite(
|
||||
self, db_path: str, table_name: str | None = None, if_exists: str = "replace"
|
||||
) -> int:
|
||||
"""
|
||||
Reads from Power BI and writes into SQLite. Returns row count written.
|
||||
"""
|
||||
df = await self.read_data()
|
||||
if df.empty:
|
||||
return 0
|
||||
|
||||
# Normalize SQLite table name
|
||||
tn = table_name or self.table_name
|
||||
tn = re.sub(r"[^\w]+", "_", tn).lower().strip("_")
|
||||
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
df.to_sql(tn, conn, if_exists=if_exists, index=False)
|
||||
return len(df)
|
||||
|
||||
@staticmethod
|
||||
def _get_access_token_sync(
|
||||
tenant_id: str,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|||
from src.settings import settings
|
||||
from src.dataprocessor.domain.powerbi_reader import PowerBIReader
|
||||
from src.dataprocessor.domain.preprocessor import Preprocessor
|
||||
from dataprocessor.domain.base_datasaver import BaseDataSaver
|
||||
from src.dataprocessor.domain.base_datasaver import BaseDataSaver
|
||||
from src.dataprocessor.domain.sqlite_datasaver import SQLiteDataSaver
|
||||
|
||||
|
||||
|
|
|
|||
1
src/dataquery/__init__.py
Normal file
1
src/dataquery/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# noqa
|
||||
|
|
@ -1 +1,85 @@
|
|||
# TODO: Endpoint for SQL queries (read-only)
|
||||
"""Router for data query endpoints."""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
|
||||
from src.dataquery.schemas import (
|
||||
DatabaseSchemaResponse,
|
||||
SqlQueryRequest,
|
||||
SqlQueryResponse,
|
||||
TableSchemaResponse,
|
||||
)
|
||||
from src.dataquery.service import DataQueryService
|
||||
from src.dependencies import require_db_api_key
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/query", response_model=SqlQueryResponse)
async def execute_sql_query(
    request: SqlQueryRequest,
    _: None = Depends(require_db_api_key),
) -> SqlQueryResponse:
    """Run a client-supplied SQL query through the data query service.

    The query text is passed to ``DataQueryService.execute_query``. The
    service instance is created per request and is always closed again,
    whether the call returned or raised.

    Args:
        request: Request body holding the SQL query string.

    Returns:
        The ``SqlQueryResponse`` produced by the service (rows on success,
        an error message otherwise).
    """
    query_service = await DataQueryService.create()
    try:
        return await query_service.execute_query(query=request.query)
    finally:
        # Release the per-request service regardless of the outcome.
        await query_service.close()
|
||||
|
||||
|
||||
@router.get("/schema", response_model=DatabaseSchemaResponse)
async def get_database_schema(
    _: None = Depends(require_db_api_key),
) -> DatabaseSchemaResponse:
    """List the tables of the database.

    Delegates to ``DataQueryService.get_database_schema`` using a service
    instance that lives only for this request.

    Returns:
        The ``DatabaseSchemaResponse`` produced by the service.
    """
    query_service = await DataQueryService.create()
    try:
        return await query_service.get_database_schema()
    finally:
        # Release the per-request service regardless of the outcome.
        await query_service.close()
|
||||
|
||||
|
||||
@router.get("/schema/{table_name}", response_model=TableSchemaResponse)
async def get_table_schema(
    table_name: str = Path(..., description="Name of the table to inspect"),
    _: None = Depends(require_db_api_key),
) -> TableSchemaResponse:
    """Describe a single table of the database.

    Delegates to ``DataQueryService.get_table_schema`` using a service
    instance that lives only for this request.

    Args:
        table_name: Name of the table to inspect (taken from the URL path).

    Returns:
        The ``TableSchemaResponse`` produced by the service.
    """
    query_service = await DataQueryService.create()
    try:
        return await query_service.get_table_schema(table_name=table_name)
    finally:
        # Release the per-request service regardless of the outcome.
        await query_service.close()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
"""Schemas for the data query service."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SqlQueryRequest(BaseModel):
    """Request body carrying the SQL text a client wants to run."""

    # Required, non-empty string.
    query: str = Field(
        min_length=1,
        description="SQL query to execute (SELECT statements only)",
    )
|
||||
|
||||
|
||||
class SqlQueryResponse(BaseModel):
    """Response body describing the outcome of a SQL query."""

    success: bool = Field(
        description="Indicates if the query was executed successfully",
    )
    data: List[Dict[str, Any]] = Field(
        description="Query results as list of dictionaries",
    )
    columns: List[str] = Field(
        description="Column names in the result set",
    )
    row_count: int = Field(
        description="Number of rows returned",
    )
    # The only optional field; defaults to None.
    message: Optional[str] = Field(
        default=None,
        description="Additional information or error message",
    )
|
||||
|
||||
|
||||
class TableInfo(BaseModel):
    """Name and size of one database table."""

    name: str = Field(
        description="Table name",
    )
    row_count: int = Field(
        description="Number of rows in the table",
    )
|
||||
|
||||
|
||||
class DatabaseSchemaResponse(BaseModel):
    """Response body listing the tables of the database."""

    success: bool = Field(
        description="Indicates if the schema retrieval was successful",
    )
    tables: List[TableInfo] = Field(
        description="List of tables in the database",
    )
    table_count: int = Field(
        description="Total number of tables",
    )
|
||||
|
||||
|
||||
class ColumnInfo(BaseModel):
    """Description of a single table column."""

    name: str = Field(
        description="Column name",
    )
    type: str = Field(
        description="Column data type",
    )
    nullable: bool = Field(
        description="Whether the column can be null",
    )
    primary_key: bool = Field(
        description="Whether the column is a primary key",
    )
|
||||
|
||||
|
||||
class TableSchemaResponse(BaseModel):
    """Response body describing one table's structure plus a few sample rows."""

    success: bool = Field(
        description="Indicates if the table schema retrieval was successful",
    )
    table_name: str = Field(
        description="Name of the table",
    )
    columns: List[ColumnInfo] = Field(
        description="List of columns in the table",
    )
    row_count: int = Field(
        description="Number of rows in the table",
    )
    sample_data: List[Dict[str, Any]] = Field(
        description="Sample rows from the table (up to 5 rows)",
    )
|
||||
244
src/dataquery/service.py
Normal file
244
src/dataquery/service.py
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
"""Service for handling database queries and schema operations."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import text, inspect
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from src.dataquery.schemas import (
|
||||
ColumnInfo,
|
||||
DatabaseSchemaResponse,
|
||||
SqlQueryResponse,
|
||||
TableInfo,
|
||||
TableSchemaResponse,
|
||||
)
|
||||
from src.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class DataQueryService:
    """Read-only query and schema-inspection service for the SQLite database.

    An instance wraps one ``AsyncEngine``. Callers should ``await close()``
    when they are done so the engine's connections are disposed.
    """

    engine: "AsyncEngine"

    # The two patterns below are deliberately left unannotated so that the
    # dataclass machinery treats them as plain class attributes, not fields.

    # Quoted tokens and comments, matched in ONE left-to-right pass so that a
    # comment marker inside a string literal (or a quote inside a comment) is
    # never misread.
    _SQL_NOISE = re.compile(
        r"'(?:[^']|'')*'"  # single-quoted string literal
        r'|"(?:[^"]|"")*"'  # double-quoted identifier
        r"|`(?:[^`]|``)*`"  # backtick-quoted identifier
        r"|\[[^\]]*\]"  # bracket-quoted identifier
        r"|--[^\n]*"  # line comment
        r"|/\*.*?\*/",  # block comment
        re.DOTALL,
    )

    # Keywords that must not appear in the (literal-free) query text.
    # NOTE: REPLACE also blocks the SQL replace() function; kept because the
    # original validator rejected it as well.
    _FORBIDDEN_KEYWORDS = re.compile(
        r"\b(?:INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|REPLACE|EXEC"
        r"|EXECUTE|ATTACH|DETACH|PRAGMA|VACUUM)\b",
        re.IGNORECASE,
    )

    @classmethod
    async def create(cls) -> "DataQueryService":
        """Create a service bound to the SQLite file at ``settings.DB_PATH``."""
        db_url = f"sqlite+aiosqlite:///{settings.DB_PATH}"
        engine = create_async_engine(db_url, future=True)
        return cls(engine=engine)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (``"`` doubled)."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _query_failure(message: str) -> "SqlQueryResponse":
        """Build the empty ``SqlQueryResponse`` used for every failure path."""
        return SqlQueryResponse(
            success=False, data=[], columns=[], row_count=0, message=message
        )

    @staticmethod
    def _table_failure(table_name: str) -> "TableSchemaResponse":
        """Build the empty ``TableSchemaResponse`` used for every failure path."""
        return TableSchemaResponse(
            success=False,
            table_name=table_name,
            columns=[],
            row_count=0,
            sample_data=[],
        )

    def _validate_select_query(self, *, query: str) -> bool:
        """Check that *query* is a single SELECT statement.

        String literals, quoted identifiers and comments are blanked out
        first, so their content can neither hide a second statement nor
        trigger a false keyword hit. The remaining text must start with
        SELECT, contain no ``;`` other than trailing ones, and contain none
        of the forbidden keywords.

        Args:
            query: The SQL query to validate.

        Returns:
            True if the query is accepted, False otherwise.
        """

        def _blank(match: re.Match) -> str:
            token = match.group(0)
            # Comments vanish; quoted tokens become an empty-literal placeholder.
            return " " if token.startswith(("--", "/*")) else "''"

        cleaned = self._SQL_NOISE.sub(_blank, query)
        cleaned = re.sub(r"\s+", " ", cleaned).strip()
        # A trailing statement terminator is harmless.
        cleaned = cleaned.rstrip("; ")

        # Any remaining semicolon means more than one statement.
        if ";" in cleaned:
            return False
        if not re.match(r"SELECT\b", cleaned, re.IGNORECASE):
            return False
        return self._FORBIDDEN_KEYWORDS.search(cleaned) is None

    async def execute_query(self, *, query: str) -> "SqlQueryResponse":
        """Execute a SELECT query and return the results.

        Args:
            query: The SQL query to execute.

        Returns:
            SqlQueryResponse with the rows on success. Validation failures
            and database errors are reported through ``success=False`` and
            ``message`` rather than raised.
        """
        if not self._validate_select_query(query=query):
            return self._query_failure("Only SELECT queries are allowed")

        try:
            # connect() instead of begin(): a read-only query has nothing to
            # commit, and leaving the block rolls the transaction back.
            async with self.engine.connect() as conn:
                result = await conn.execute(text(query))
                # Column names exist even when the result set has no rows.
                columns = list(result.keys())
                data = [dict(row._mapping) for row in result.fetchall()]

            return SqlQueryResponse(
                success=True,
                data=data,
                columns=columns,
                row_count=len(data),
                message=f"Query executed successfully. {len(data)} rows returned.",
            )
        except SQLAlchemyError as e:
            logger.exception("Database error executing query")
            return self._query_failure(f"Database error: {str(e)}")
        except Exception as e:
            logger.exception("Unexpected error executing query")
            return self._query_failure(f"Unexpected error: {str(e)}")

    async def get_database_schema(self) -> "DatabaseSchemaResponse":
        """Get the name and row count of every table in the database.

        Returns:
            DatabaseSchemaResponse with table information; on any error an
            empty response with ``success=False``.
        """
        try:
            async with self.engine.connect() as conn:
                # The Inspector is a sync API; with an async driver it has to
                # run inside run_sync() on the underlying sync connection.
                table_names = await conn.run_sync(
                    lambda sync_conn: inspect(sync_conn).get_table_names()
                )

                tables = []
                for name in table_names:
                    result = await conn.execute(
                        text(
                            "SELECT COUNT(*) as count FROM "
                            f"{self._quote_identifier(name)}"
                        )
                    )
                    tables.append(
                        TableInfo(name=name, row_count=result.scalar() or 0)
                    )

            return DatabaseSchemaResponse(
                success=True, tables=tables, table_count=len(tables)
            )
        except SQLAlchemyError:
            logger.exception("Database error getting schema")
            return DatabaseSchemaResponse(success=False, tables=[], table_count=0)
        except Exception:
            logger.exception("Unexpected error getting schema")
            return DatabaseSchemaResponse(success=False, tables=[], table_count=0)

    async def get_table_schema(self, *, table_name: str) -> "TableSchemaResponse":
        """Get columns, row count and up to five sample rows of one table.

        Args:
            table_name: Name of the table to inspect.

        Returns:
            TableSchemaResponse with table structure and sample data; an
            empty response with ``success=False`` if the table does not
            exist or an error occurs.
        """

        def _reflect(sync_conn):
            """Return (columns, primary key names), or None for an unknown table."""
            inspector = inspect(sync_conn)
            if table_name not in inspector.get_table_names():
                return None
            pk_constraint = inspector.get_pk_constraint(table_name)
            return (
                inspector.get_columns(table_name),
                pk_constraint.get("constrained_columns", []) or [],
            )

        try:
            async with self.engine.connect() as conn:
                # Sync Inspector must be driven through run_sync() (see above).
                reflected = await conn.run_sync(_reflect)
                if reflected is None:
                    return self._table_failure(table_name)
                columns_info, primary_keys = reflected

                columns = [
                    ColumnInfo(
                        name=col_info["name"],
                        type=str(col_info["type"]),
                        nullable=col_info["nullable"],
                        primary_key=col_info["name"] in primary_keys,
                    )
                    for col_info in columns_info
                ]

                # The name was checked against the inspector's table list
                # above; quoting additionally keeps names with spaces or
                # keywords working.
                quoted = self._quote_identifier(table_name)

                count_result = await conn.execute(
                    text(f"SELECT COUNT(*) as count FROM {quoted}")
                )
                row_count = count_result.scalar() or 0

                sample_result = await conn.execute(
                    text(f"SELECT * FROM {quoted} LIMIT 5")
                )
                sample_data = [
                    dict(row._mapping) for row in sample_result.fetchall()
                ]

            return TableSchemaResponse(
                success=True,
                table_name=table_name,
                columns=columns,
                row_count=row_count,
                sample_data=sample_data,
            )
        except SQLAlchemyError:
            logger.exception(
                "Database error getting table schema for %s", table_name
            )
            return self._table_failure(table_name)
        except Exception:
            logger.exception(
                "Unexpected error getting table schema for %s", table_name
            )
            return self._table_failure(table_name)

    async def close(self) -> None:
        """Dispose of the database engine."""
        await self.engine.dispose()
|
||||
|
|
@ -30,3 +30,27 @@ def require_pp_api_key(*, api_key: str = api_key_header) -> None:
|
|||
detail="Invalid or missing API key",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
|
||||
|
||||
# API key header for database endpoint authentication
|
||||
db_api_key_header: SecurityBase = APIKeyHeader(
|
||||
name="X-DB-API-Key",
|
||||
description="API key for database query access",
|
||||
)
|
||||
|
||||
|
||||
def require_db_api_key(*, api_key: str = db_api_key_header) -> None:
    """Validate the database API key.

    The supplied key is compared against ``settings.DB_ENDPOINT_API_KEY``
    in constant time so response timing does not reveal how much of the key
    matched.

    Args:
        api_key: The API key from the X-DB-API-Key header.

    Raises:
        HTTPException: 401 if the API key is invalid or missing.
    """
    # NOTE(review): the default is the APIKeyHeader object itself rather than
    # Security(...)/Depends(...) - confirm FastAPI really resolves this as the
    # header dependency and not as a plain parameter default.

    # Local import so this function does not depend on the file's import block.
    import secrets

    expected = settings.DB_ENDPOINT_API_KEY
    # compare_digest needs matching str/bytes types; encoding also makes
    # non-ASCII keys safe to compare. Non-string input is simply rejected.
    key_matches = isinstance(api_key, str) and secrets.compare_digest(
        api_key.encode("utf-8"), expected.encode("utf-8")
    )
    if not key_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing database API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )
|
||||
|
|
|
|||
10
src/main.py
10
src/main.py
|
|
@ -4,6 +4,9 @@ import logging
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from src.dataprocessor.router import router as dataprocessor_router
|
||||
from src.dataquery.router import router as dataquery_router
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -26,6 +29,13 @@ app.add_middleware(
|
|||
)
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(
|
||||
dataprocessor_router, prefix="/api/v1/dataprocessor", tags=["dataprocessor"]
|
||||
)
|
||||
app.include_router(dataquery_router, prefix="/api/v1/dataquery", tags=["dataquery"])
|
||||
|
||||
|
||||
# Add a simple root endpoint
|
||||
@app.get("/")
|
||||
async def root():
|
||||
|
|
|
|||
|
|
@ -7,10 +7,7 @@ from pydantic_settings import BaseSettings
|
|||
class Settings(BaseSettings):
|
||||
"""Application settings."""
|
||||
|
||||
# Preprocessor API key to access this app.
|
||||
PP_API_KEY: str = Field(
|
||||
..., description="API key to access this app for preprocessing."
|
||||
)
|
||||
# --- General Settings ---
|
||||
|
||||
# Preprocessing configuration file path.
|
||||
PP_CONFIG_PATH: str = Field(
|
||||
|
|
@ -24,6 +21,18 @@ class Settings(BaseSettings):
|
|||
"data/database.sqlite", description="Path to the SQLite database."
|
||||
)
|
||||
|
||||
# --- API Keys ---
|
||||
|
||||
# Preprocessor API key to access this app.
|
||||
PP_API_KEY: str = Field(
|
||||
..., description="API key to access this app for preprocessing."
|
||||
)
|
||||
|
||||
# API key needed to access the endpoint that queries the database.
|
||||
DB_ENDPOINT_API_KEY: str = Field(
|
||||
..., description="API key needed to access the database query endpoint."
|
||||
)
|
||||
|
||||
# --- Power BI Settings ---
|
||||
|
||||
# Power BI base URL.
|
||||
|
|
|
|||
Loading…
Reference in a new issue