feat: add db query router

This commit is contained in:
Christopher Gondek 2025-09-26 17:07:42 +02:00
parent ed8dd344d7
commit 9bc7cd774c
11 changed files with 449 additions and 27 deletions

View file

@ -0,0 +1 @@
# noqa

View file

@ -0,0 +1 @@
# noqa

View file

@ -1,11 +1,9 @@
from msal import ConfidentialClientApplication, SerializableTokenCache
import anyio # comes with FastAPI via Starlette/AnyIO
from dataclasses import dataclass
from settings import settings
from src.settings import settings
import pandas as pd
import httpx
import re
import sqlite3
@dataclass
@ -76,24 +74,6 @@ class PowerBIReader:
df.columns = [_strip_qual(c) for c in df.columns]
return df
async def to_sqlite(
self, db_path: str, table_name: str | None = None, if_exists: str = "replace"
) -> int:
"""
Reads from Power BI and writes into SQLite. Returns row count written.
"""
df = await self.read_data()
if df.empty:
return 0
# Normalize SQLite table name
tn = table_name or self.table_name
tn = re.sub(r"[^\w]+", "_", tn).lower().strip("_")
with sqlite3.connect(db_path) as conn:
df.to_sql(tn, conn, if_exists=if_exists, index=False)
return len(df)
@staticmethod
def _get_access_token_sync(
tenant_id: str,

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from src.settings import settings
from src.dataprocessor.domain.powerbi_reader import PowerBIReader
from src.dataprocessor.domain.preprocessor import Preprocessor
from dataprocessor.domain.base_datasaver import BaseDataSaver
from src.dataprocessor.domain.base_datasaver import BaseDataSaver
from src.dataprocessor.domain.sqlite_datasaver import SQLiteDataSaver

View file

@ -0,0 +1 @@
# noqa

View file

@ -1 +1,85 @@
# TODO: Endpoint for SQL queries (read-only)
"""Router for data query endpoints."""
import logging
from fastapi import APIRouter, Depends, Path
from src.dataquery.schemas import (
DatabaseSchemaResponse,
SqlQueryRequest,
SqlQueryResponse,
TableSchemaResponse,
)
from src.dataquery.service import DataQueryService
from src.dependencies import require_db_api_key
# Router collecting the data query endpoints defined in this module.
router = APIRouter()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@router.post("/query", response_model=SqlQueryResponse)
async def execute_sql_query(
    request: SqlQueryRequest,
    _: None = Depends(require_db_api_key),
) -> SqlQueryResponse:
    """Execute a SELECT SQL query against the database.

    This endpoint only allows SELECT queries for security reasons.
    All other SQL operations (INSERT, UPDATE, DELETE, etc.) are blocked.

    Args:
        request: The SQL query request containing the query string.

    Returns:
        SqlQueryResponse with query results or error information.
    """
    # A fresh service (and engine) is built for this request only.
    query_service = await DataQueryService.create()
    try:
        return await query_service.execute_query(query=request.query)
    finally:
        # Always release the engine, whether the query succeeded or raised.
        await query_service.close()
@router.get("/schema", response_model=DatabaseSchemaResponse)
async def get_database_schema(
    _: None = Depends(require_db_api_key),
) -> DatabaseSchemaResponse:
    """Get information about all tables in the database.

    Returns a list of all tables with their names and row counts.

    Returns:
        DatabaseSchemaResponse with table information.
    """
    # A fresh service (and engine) is built for this request only.
    query_service = await DataQueryService.create()
    try:
        return await query_service.get_database_schema()
    finally:
        # Always release the engine, whether the lookup succeeded or raised.
        await query_service.close()
@router.get("/schema/{table_name}", response_model=TableSchemaResponse)
async def get_table_schema(
    table_name: str = Path(..., description="Name of the table to inspect"),
    _: None = Depends(require_db_api_key),
) -> TableSchemaResponse:
    """Get detailed information about a specific table.

    Returns column information, data types, constraints, and sample data
    for the specified table.

    Args:
        table_name: Name of the table to inspect.

    Returns:
        TableSchemaResponse with table structure and sample data.
    """
    # A fresh service (and engine) is built for this request only.
    query_service = await DataQueryService.create()
    try:
        return await query_service.get_table_schema(table_name=table_name)
    finally:
        # Always release the engine, whether the lookup succeeded or raised.
        await query_service.close()

View file

@ -0,0 +1,68 @@
"""Schemas for the data query service."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class SqlQueryRequest(BaseModel):
    """Request schema for SQL queries."""

    # Required; min_length=1 rejects an empty query string at validation time.
    query: str = Field(
        description="SQL query to execute (SELECT statements only)",
        min_length=1,
    )
class SqlQueryResponse(BaseModel):
    """Response schema for SQL query results."""

    # All fields except `message` are required.
    success: bool = Field(
        description="Indicates if the query was executed successfully",
    )
    # One dict per result row, keyed by column name.
    data: List[Dict[str, Any]] = Field(
        description="Query results as list of dictionaries",
    )
    columns: List[str] = Field(description="Column names in the result set")
    row_count: int = Field(description="Number of rows returned")
    # Optional free-text status; defaults to None when omitted.
    message: Optional[str] = Field(
        default=None,
        description="Additional information or error message",
    )
class TableInfo(BaseModel):
    """Information about a database table."""

    # Both fields are required.
    name: str = Field(description="Table name")
    row_count: int = Field(description="Number of rows in the table")
class DatabaseSchemaResponse(BaseModel):
    """Response schema for database schema information."""

    # All fields are required.
    success: bool = Field(
        description="Indicates if the schema retrieval was successful",
    )
    tables: List[TableInfo] = Field(description="List of tables in the database")
    table_count: int = Field(description="Total number of tables")
class ColumnInfo(BaseModel):
    """Information about a table column."""

    # All fields are required.
    name: str = Field(description="Column name")
    type: str = Field(description="Column data type")
    nullable: bool = Field(description="Whether the column can be null")
    primary_key: bool = Field(description="Whether the column is a primary key")
class TableSchemaResponse(BaseModel):
    """Response schema for table structure information."""

    # All fields are required.
    success: bool = Field(
        description="Indicates if the table schema retrieval was successful",
    )
    table_name: str = Field(description="Name of the table")
    columns: List[ColumnInfo] = Field(description="List of columns in the table")
    row_count: int = Field(description="Number of rows in the table")
    # One dict per sampled row, keyed by column name.
    sample_data: List[Dict[str, Any]] = Field(
        description="Sample rows from the table (up to 5 rows)",
    )

244
src/dataquery/service.py Normal file
View file

@ -0,0 +1,244 @@
"""Service for handling database queries and schema operations."""
import re
import logging
from dataclasses import dataclass
from sqlalchemy import text, inspect
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.exc import SQLAlchemyError
from src.dataquery.schemas import (
ColumnInfo,
DatabaseSchemaResponse,
SqlQueryResponse,
TableInfo,
TableSchemaResponse,
)
from src.settings import settings
# Module-level logger named after this module; used to report query and schema errors.
logger = logging.getLogger(__name__)
@dataclass
class DataQueryService:
    """Service for executing database queries and retrieving schema information.

    Every public method reports failure through the ``success`` flag of its
    response object instead of raising.

    Attributes:
        engine: Async SQLAlchemy engine all operations run on. ``close()``
            disposes of it.
    """

    engine: AsyncEngine

    # One left-to-right pass over the SQL text that recognises quoted spans
    # and comments together, so a comment marker inside a string literal (or
    # a quote inside a comment) is never misread.
    _LITERAL_OR_COMMENT = re.compile(
        r"'(?:[^']|'')*'"  # single-quoted string literal ('' escapes a quote)
        r'|"(?:[^"]|"")*"'  # double-quoted identifier
        r"|--[^\n]*"  # line comment
        r"|/\*.*?\*/",  # block comment
        re.DOTALL,
    )

    # Statement keywords that must not appear outside literals. REPLACE is
    # only rejected when it is not a function call, so the scalar function
    # REPLACE(x, y, z) stays usable in SELECT lists.
    _FORBIDDEN_KEYWORD = re.compile(
        r"\b(?:INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|EXEC|EXECUTE)\b"
        r"|\bREPLACE\b(?!\s*\()",
        re.IGNORECASE,
    )

    @classmethod
    async def create(cls) -> "DataQueryService":
        """Create a new instance of DataQueryService.

        Returns:
            A service bound to a new aiosqlite engine for ``settings.DB_PATH``.
        """
        db_url = f"sqlite+aiosqlite:///{settings.DB_PATH}"
        engine = create_async_engine(db_url, future=True)
        return cls(engine=engine)

    @staticmethod
    def _mask_token(match: "re.Match[str]") -> str:
        """Replace a quoted span with an empty one and a comment with a space."""
        token = match.group(0)
        if token.startswith("'"):
            return "''"
        if token.startswith('"'):
            return '""'
        return " "

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so names containing spaces,
        punctuation or reserved words are safe to interpolate.
        """
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _failed_query(message: str) -> SqlQueryResponse:
        """Build the empty, unsuccessful query response carrying *message*."""
        return SqlQueryResponse(
            success=False, data=[], columns=[], row_count=0, message=message
        )

    @staticmethod
    def _failed_table_schema(table_name: str) -> TableSchemaResponse:
        """Build the empty, unsuccessful table-schema response."""
        return TableSchemaResponse(
            success=False,
            table_name=table_name,
            columns=[],
            row_count=0,
            sample_data=[],
        )

    def _validate_select_query(self, *, query: str) -> bool:
        """Validate that the query is a single SELECT statement.

        String literals, quoted identifiers and comments are masked before
        any check, so their contents can neither hide a second statement nor
        trigger a false rejection.

        Args:
            query: The SQL query to validate.

        Returns:
            True if the query is one SELECT statement free of forbidden
            keywords, False otherwise.
        """
        cleaned = self._LITERAL_OR_COMMENT.sub(self._mask_token, query)
        cleaned = re.sub(r"\s+", " ", cleaned).strip()

        # A single trailing terminator is harmless; any other semicolon
        # means more than one statement.
        if cleaned.endswith(";"):
            cleaned = cleaned[:-1].rstrip()
        if ";" in cleaned:
            return False

        if not re.match(r"SELECT\b", cleaned, re.IGNORECASE):
            return False

        return self._FORBIDDEN_KEYWORD.search(cleaned) is None

    async def execute_query(self, *, query: str) -> SqlQueryResponse:
        """Execute a SELECT query and return the results.

        Args:
            query: The SQL query to execute.

        Returns:
            SqlQueryResponse with the rows and column names on success, or
            with ``success=False`` and an explanatory message when the query
            is rejected or the database reports an error.
        """
        if not self._validate_select_query(query=query):
            return self._failed_query("Only SELECT queries are allowed")

        try:
            async with self.engine.begin() as conn:
                # Defence in depth: the textual validation above is a
                # heuristic, so also have SQLite itself refuse writes on
                # this connection.
                await conn.exec_driver_sql("PRAGMA query_only = ON")
                # exec_driver_sql hands the string to the driver untouched;
                # text() would treat ":word" inside string literals as a
                # bind parameter and fail on valid queries.
                result = await conn.exec_driver_sql(query)
                # Column names are reported even when no rows match.
                columns = list(result.keys())
                data = [dict(row._mapping) for row in result.fetchall()]

            return SqlQueryResponse(
                success=True,
                data=data,
                columns=columns,
                row_count=len(data),
                message=f"Query executed successfully. {len(data)} rows returned.",
            )
        except SQLAlchemyError as e:
            logger.error("Database error executing query: %s", e)
            return self._failed_query(f"Database error: {e}")
        except Exception as e:
            logger.error("Unexpected error executing query: %s", e)
            return self._failed_query(f"Unexpected error: {e}")

    async def get_database_schema(self) -> DatabaseSchemaResponse:
        """Get information about all tables in the database.

        Returns:
            DatabaseSchemaResponse listing every table with its row count,
            or an empty response with ``success=False`` on any error.
        """
        try:
            async with self.engine.begin() as conn:
                # The Inspector does synchronous I/O, so it has to run
                # inside run_sync on the connection's sync facade.
                table_names = await conn.run_sync(
                    lambda sync_conn: inspect(sync_conn).get_table_names()
                )

                tables = []
                for table_name in table_names:
                    quoted = self._quote_identifier(table_name)
                    result = await conn.exec_driver_sql(
                        f"SELECT COUNT(*) as count FROM {quoted}"
                    )
                    tables.append(
                        TableInfo(name=table_name, row_count=result.scalar() or 0)
                    )

            return DatabaseSchemaResponse(
                success=True, tables=tables, table_count=len(tables)
            )
        except SQLAlchemyError as e:
            logger.error("Database error getting schema: %s", e)
            return DatabaseSchemaResponse(success=False, tables=[], table_count=0)
        except Exception as e:
            logger.error("Unexpected error getting schema: %s", e)
            return DatabaseSchemaResponse(success=False, tables=[], table_count=0)

    async def get_table_schema(self, *, table_name: str) -> TableSchemaResponse:
        """Get detailed information about a specific table.

        Args:
            table_name: Name of the table to inspect.

        Returns:
            TableSchemaResponse with the columns, row count and up to five
            sample rows; an empty response with ``success=False`` when the
            table does not exist or an error occurs.
        """

        def _describe(sync_conn):
            # Runs synchronously via run_sync; returns None for an unknown
            # table so no SQL is ever built from an unverified name.
            inspector = inspect(sync_conn)
            if table_name not in inspector.get_table_names():
                return None
            return (
                inspector.get_columns(table_name),
                inspector.get_pk_constraint(table_name),
            )

        try:
            async with self.engine.begin() as conn:
                described = await conn.run_sync(_describe)
                if described is None:
                    return self._failed_table_schema(table_name)

                columns_info, pk_constraint = described
                primary_keys = pk_constraint.get("constrained_columns") or []
                columns = [
                    ColumnInfo(
                        name=col_info["name"],
                        type=str(col_info["type"]),
                        nullable=col_info["nullable"],
                        primary_key=col_info["name"] in primary_keys,
                    )
                    for col_info in columns_info
                ]

                quoted = self._quote_identifier(table_name)

                count_result = await conn.exec_driver_sql(
                    f"SELECT COUNT(*) as count FROM {quoted}"
                )
                row_count = count_result.scalar() or 0

                # Sample data (up to 5 rows).
                sample_result = await conn.exec_driver_sql(
                    f"SELECT * FROM {quoted} LIMIT 5"
                )
                sample_data = [
                    dict(row._mapping) for row in sample_result.fetchall()
                ]

            return TableSchemaResponse(
                success=True,
                table_name=table_name,
                columns=columns,
                row_count=row_count,
                sample_data=sample_data,
            )
        except SQLAlchemyError as e:
            logger.error(
                "Database error getting table schema for %s: %s", table_name, e
            )
            return self._failed_table_schema(table_name)
        except Exception as e:
            logger.error(
                "Unexpected error getting table schema for %s: %s", table_name, e
            )
            return self._failed_table_schema(table_name)

    async def close(self) -> None:
        """Close the database engine."""
        await self.engine.dispose()

View file

@ -30,3 +30,27 @@ def require_pp_api_key(*, api_key: str = api_key_header) -> None:
detail="Invalid or missing API key",
headers={"WWW-Authenticate": "ApiKey"},
)
# API key header for database endpoint authentication.
# NOTE(review): this object is used directly as the default value of
# `api_key` below rather than through Depends()/Security() — confirm FastAPI
# resolves it as a header dependency and not as a plain query parameter.
db_api_key_header: SecurityBase = APIKeyHeader(
    name="X-DB-API-Key",
    description="API key for database query access",
)


def require_db_api_key(*, api_key: str = db_api_key_header) -> None:
    """Validate the database API key.

    Args:
        api_key: The API key from the X-DB-API-Key header.

    Raises:
        HTTPException: 401 if the API key is invalid or missing.
    """
    import secrets

    expected = settings.DB_ENDPOINT_API_KEY
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Operands are encoded because compare_digest rejects
    # non-ASCII str; the isinstance guard turns any non-string value into a
    # 401 instead of a TypeError.
    key_matches = isinstance(api_key, str) and secrets.compare_digest(
        api_key.encode("utf-8"), expected.encode("utf-8")
    )
    if not key_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing database API key",
            headers={"WWW-Authenticate": "ApiKey"},
        )

View file

@ -4,6 +4,9 @@ import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.dataprocessor.router import router as dataprocessor_router
from src.dataquery.router import router as dataquery_router
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -26,6 +29,13 @@ app.add_middleware(
)
# Include routers: each feature router is mounted under its own versioned
# URL prefix and given a matching tag.
app.include_router(
    dataprocessor_router, prefix="/api/v1/dataprocessor", tags=["dataprocessor"]
)
app.include_router(dataquery_router, prefix="/api/v1/dataquery", tags=["dataquery"])
# Add a simple root endpoint
@app.get("/")
async def root():

View file

@ -7,10 +7,7 @@ from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings."""
# Preprocessor API key to access this app.
PP_API_KEY: str = Field(
..., description="API key to access this app for preprocessing."
)
# --- General Settings ---
# Preprocessing configuration file path.
PP_CONFIG_PATH: str = Field(
@ -24,6 +21,18 @@ class Settings(BaseSettings):
"data/database.sqlite", description="Path to the SQLite database."
)
# --- API Keys ---
# Preprocessor API key to access this app.
PP_API_KEY: str = Field(
..., description="API key to access this app for preprocessing."
)
# API key needed to access the endpoint that queries the database.
DB_ENDPOINT_API_KEY: str = Field(
..., description="API key needed to access the database query endpoint."
)
# --- Power BI Settings ---
# Power BI base URL.