feat: add db query router

This commit is contained in:
Christopher Gondek 2025-09-26 17:07:42 +02:00
parent ed8dd344d7
commit 9bc7cd774c
11 changed files with 449 additions and 27 deletions

View file

@ -0,0 +1 @@
# noqa

View file

@ -0,0 +1 @@
# noqa

View file

@ -1,11 +1,9 @@
from msal import ConfidentialClientApplication, SerializableTokenCache from msal import ConfidentialClientApplication, SerializableTokenCache
import anyio # comes with FastAPI via Starlette/AnyIO import anyio # comes with FastAPI via Starlette/AnyIO
from dataclasses import dataclass from dataclasses import dataclass
from settings import settings from src.settings import settings
import pandas as pd import pandas as pd
import httpx import httpx
import re
import sqlite3
@dataclass @dataclass
@ -76,24 +74,6 @@ class PowerBIReader:
df.columns = [_strip_qual(c) for c in df.columns] df.columns = [_strip_qual(c) for c in df.columns]
return df return df
async def to_sqlite(
self, db_path: str, table_name: str | None = None, if_exists: str = "replace"
) -> int:
"""
Reads from Power BI and writes into SQLite. Returns row count written.
"""
df = await self.read_data()
if df.empty:
return 0
# Normalize SQLite table name
tn = table_name or self.table_name
tn = re.sub(r"[^\w]+", "_", tn).lower().strip("_")
with sqlite3.connect(db_path) as conn:
df.to_sql(tn, conn, if_exists=if_exists, index=False)
return len(df)
@staticmethod @staticmethod
def _get_access_token_sync( def _get_access_token_sync(
tenant_id: str, tenant_id: str,

View file

@ -4,7 +4,7 @@ from dataclasses import dataclass
from src.settings import settings from src.settings import settings
from src.dataprocessor.domain.powerbi_reader import PowerBIReader from src.dataprocessor.domain.powerbi_reader import PowerBIReader
from src.dataprocessor.domain.preprocessor import Preprocessor from src.dataprocessor.domain.preprocessor import Preprocessor
from dataprocessor.domain.base_datasaver import BaseDataSaver from src.dataprocessor.domain.base_datasaver import BaseDataSaver
from src.dataprocessor.domain.sqlite_datasaver import SQLiteDataSaver from src.dataprocessor.domain.sqlite_datasaver import SQLiteDataSaver

View file

@ -0,0 +1 @@
# noqa

View file

@ -1 +1,85 @@
# TODO: Endpoint for SQL queries (read-only) """Router for data query endpoints."""
import logging
from fastapi import APIRouter, Depends, Path
from src.dataquery.schemas import (
DatabaseSchemaResponse,
SqlQueryRequest,
SqlQueryResponse,
TableSchemaResponse,
)
from src.dataquery.service import DataQueryService
from src.dependencies import require_db_api_key
router = APIRouter()
# Set up logging
logger = logging.getLogger(__name__)
@router.post("/query", response_model=SqlQueryResponse)
async def execute_sql_query(
request: SqlQueryRequest,
_: None = Depends(require_db_api_key),
) -> SqlQueryResponse:
"""Execute a SELECT SQL query against the database.
This endpoint only allows SELECT queries for security reasons.
All other SQL operations (INSERT, UPDATE, DELETE, etc.) are blocked.
Args:
request: The SQL query request containing the query string.
Returns:
SqlQueryResponse with query results or error information.
"""
service = await DataQueryService.create()
try:
result = await service.execute_query(query=request.query)
return result
finally:
await service.close()
@router.get("/schema", response_model=DatabaseSchemaResponse)
async def get_database_schema(
_: None = Depends(require_db_api_key),
) -> DatabaseSchemaResponse:
"""Get information about all tables in the database.
Returns a list of all tables with their names and row counts.
Returns:
DatabaseSchemaResponse with table information.
"""
service = await DataQueryService.create()
try:
result = await service.get_database_schema()
return result
finally:
await service.close()
@router.get("/schema/{table_name}", response_model=TableSchemaResponse)
async def get_table_schema(
table_name: str = Path(..., description="Name of the table to inspect"),
_: None = Depends(require_db_api_key),
) -> TableSchemaResponse:
"""Get detailed information about a specific table.
Returns column information, data types, constraints, and sample data
for the specified table.
Args:
table_name: Name of the table to inspect.
Returns:
TableSchemaResponse with table structure and sample data.
"""
service = await DataQueryService.create()
try:
result = await service.get_table_schema(table_name=table_name)
return result
finally:
await service.close()

View file

@ -0,0 +1,68 @@
"""Schemas for the data query service."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class SqlQueryRequest(BaseModel):
"""Request schema for SQL queries."""
query: str = Field(
..., description="SQL query to execute (SELECT statements only)", min_length=1
)
class SqlQueryResponse(BaseModel):
"""Response schema for SQL query results."""
success: bool = Field(
..., description="Indicates if the query was executed successfully"
)
data: List[Dict[str, Any]] = Field(
..., description="Query results as list of dictionaries"
)
columns: List[str] = Field(..., description="Column names in the result set")
row_count: int = Field(..., description="Number of rows returned")
message: Optional[str] = Field(
None, description="Additional information or error message"
)
class TableInfo(BaseModel):
"""Information about a database table."""
name: str = Field(..., description="Table name")
row_count: int = Field(..., description="Number of rows in the table")
class DatabaseSchemaResponse(BaseModel):
"""Response schema for database schema information."""
success: bool = Field(
..., description="Indicates if the schema retrieval was successful"
)
tables: List[TableInfo] = Field(..., description="List of tables in the database")
table_count: int = Field(..., description="Total number of tables")
class ColumnInfo(BaseModel):
"""Information about a table column."""
name: str = Field(..., description="Column name")
type: str = Field(..., description="Column data type")
nullable: bool = Field(..., description="Whether the column can be null")
primary_key: bool = Field(..., description="Whether the column is a primary key")
class TableSchemaResponse(BaseModel):
"""Response schema for table structure information."""
success: bool = Field(
..., description="Indicates if the table schema retrieval was successful"
)
table_name: str = Field(..., description="Name of the table")
columns: List[ColumnInfo] = Field(..., description="List of columns in the table")
row_count: int = Field(..., description="Number of rows in the table")
sample_data: List[Dict[str, Any]] = Field(
..., description="Sample rows from the table (up to 5 rows)"
)

244
src/dataquery/service.py Normal file
View file

@ -0,0 +1,244 @@
"""Service for handling database queries and schema operations."""
import re
import logging
from dataclasses import dataclass
from sqlalchemy import text, inspect
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.exc import SQLAlchemyError
from src.dataquery.schemas import (
ColumnInfo,
DatabaseSchemaResponse,
SqlQueryResponse,
TableInfo,
TableSchemaResponse,
)
from src.settings import settings
logger = logging.getLogger(__name__)
@dataclass
class DataQueryService:
"""Service for executing database queries and retrieving schema information."""
engine: AsyncEngine
@classmethod
async def create(cls) -> "DataQueryService":
"""Create a new instance of DataQueryService."""
db_url = f"sqlite+aiosqlite:///{settings.DB_PATH}"
engine = create_async_engine(db_url, future=True)
return cls(engine=engine)
def _validate_select_query(self, *, query: str) -> bool:
"""Validate that the query is a SELECT statement only.
Args:
query: The SQL query to validate.
Returns:
True if the query is valid (SELECT only), False otherwise.
"""
# Remove comments and normalize whitespace
cleaned_query = re.sub(r"--.*$", "", query, flags=re.MULTILINE)
cleaned_query = re.sub(r"/\*.*?\*/", "", cleaned_query, flags=re.DOTALL)
cleaned_query = re.sub(r"\s+", " ", cleaned_query.strip())
# Check if query starts with SELECT (case insensitive)
if not re.match(r"^\s*SELECT\s", cleaned_query, re.IGNORECASE):
return False
# Check for dangerous keywords that shouldn't be in SELECT queries
dangerous_keywords = [
r"\bINSERT\b",
r"\bUPDATE\b",
r"\bDELETE\b",
r"\bDROP\b",
r"\bCREATE\b",
r"\bALTER\b",
r"\bTRUNCATE\b",
r"\bREPLACE\b",
r"\bEXEC\b",
r"\bEXECUTE\b",
]
for keyword in dangerous_keywords:
if re.search(keyword, cleaned_query, re.IGNORECASE):
return False
return True
async def execute_query(self, *, query: str) -> SqlQueryResponse:
"""Execute a SELECT query and return the results.
Args:
query: The SQL query to execute.
Returns:
SqlQueryResponse with query results or error information.
"""
# Validate query is SELECT only
if not self._validate_select_query(query=query):
return SqlQueryResponse(
success=False,
data=[],
columns=[],
row_count=0,
message="Only SELECT queries are allowed",
)
try:
async with self.engine.begin() as conn:
result = await conn.execute(text(query))
rows = result.fetchall()
# Convert rows to list of dictionaries
columns = list(result.keys()) if rows else []
data = [dict(row._mapping) for row in rows]
return SqlQueryResponse(
success=True,
data=data,
columns=columns,
row_count=len(data),
message=f"Query executed successfully. {len(data)} rows returned.",
)
except SQLAlchemyError as e:
logger.error(f"Database error executing query: {e}")
return SqlQueryResponse(
success=False,
data=[],
columns=[],
row_count=0,
message=f"Database error: {str(e)}",
)
except Exception as e:
logger.error(f"Unexpected error executing query: {e}")
return SqlQueryResponse(
success=False,
data=[],
columns=[],
row_count=0,
message=f"Unexpected error: {str(e)}",
)
async def get_database_schema(self) -> DatabaseSchemaResponse:
"""Get information about all tables in the database.
Returns:
DatabaseSchemaResponse with table information.
"""
try:
async with self.engine.begin() as conn:
# Get table names using SQLAlchemy inspector
inspector = inspect(conn.sync_engine)
table_names = inspector.get_table_names()
tables = []
for table_name in table_names:
# Get row count for each table
result = await conn.execute(
text(f"SELECT COUNT(*) as count FROM {table_name}")
)
row_count = result.scalar()
tables.append(TableInfo(name=table_name, row_count=row_count or 0))
return DatabaseSchemaResponse(
success=True, tables=tables, table_count=len(tables)
)
except SQLAlchemyError as e:
logger.error(f"Database error getting schema: {e}")
return DatabaseSchemaResponse(success=False, tables=[], table_count=0)
except Exception as e:
logger.error(f"Unexpected error getting schema: {e}")
return DatabaseSchemaResponse(success=False, tables=[], table_count=0)
async def get_table_schema(self, *, table_name: str) -> TableSchemaResponse:
"""Get detailed information about a specific table.
Args:
table_name: Name of the table to inspect.
Returns:
TableSchemaResponse with table structure and sample data.
"""
try:
async with self.engine.begin() as conn:
# Check if table exists
inspector = inspect(conn.sync_engine)
table_names = inspector.get_table_names()
if table_name not in table_names:
return TableSchemaResponse(
success=False,
table_name=table_name,
columns=[],
row_count=0,
sample_data=[],
)
# Get column information
columns_info = inspector.get_columns(table_name)
pk_constraint = inspector.get_pk_constraint(table_name)
primary_keys = pk_constraint.get("constrained_columns", [])
columns = []
for col_info in columns_info:
columns.append(
ColumnInfo(
name=col_info["name"],
type=str(col_info["type"]),
nullable=col_info["nullable"],
primary_key=col_info["name"] in primary_keys,
)
)
# Get row count
result = await conn.execute(
text(f"SELECT COUNT(*) as count FROM {table_name}")
)
row_count = result.scalar() or 0
# Get sample data (up to 5 rows)
sample_result = await conn.execute(
text(f"SELECT * FROM {table_name} LIMIT 5")
)
sample_rows = sample_result.fetchall()
sample_data = [dict(row._mapping) for row in sample_rows]
return TableSchemaResponse(
success=True,
table_name=table_name,
columns=columns,
row_count=row_count,
sample_data=sample_data,
)
except SQLAlchemyError as e:
logger.error(f"Database error getting table schema for {table_name}: {e}")
return TableSchemaResponse(
success=False,
table_name=table_name,
columns=[],
row_count=0,
sample_data=[],
)
except Exception as e:
logger.error(f"Unexpected error getting table schema for {table_name}: {e}")
return TableSchemaResponse(
success=False,
table_name=table_name,
columns=[],
row_count=0,
sample_data=[],
)
async def close(self) -> None:
"""Close the database engine."""
await self.engine.dispose()

View file

@ -30,3 +30,27 @@ def require_pp_api_key(*, api_key: str = api_key_header) -> None:
detail="Invalid or missing API key", detail="Invalid or missing API key",
headers={"WWW-Authenticate": "ApiKey"}, headers={"WWW-Authenticate": "ApiKey"},
) )
# API key header for database endpoint authentication
db_api_key_header: SecurityBase = APIKeyHeader(
name="X-DB-API-Key",
description="API key for database query access",
)
def require_db_api_key(*, api_key: str = db_api_key_header) -> None:
"""Validate the database API key.
Args:
api_key: The API key from the X-DB-API-Key header.
Raises:
HTTPException: If the API key is invalid or missing.
"""
if api_key != settings.DB_ENDPOINT_API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or missing database API key",
headers={"WWW-Authenticate": "ApiKey"},
)

View file

@ -4,6 +4,9 @@ import logging
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.dataprocessor.router import router as dataprocessor_router
from src.dataquery.router import router as dataquery_router
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,6 +29,13 @@ app.add_middleware(
) )
# Include routers
app.include_router(
dataprocessor_router, prefix="/api/v1/dataprocessor", tags=["dataprocessor"]
)
app.include_router(dataquery_router, prefix="/api/v1/dataquery", tags=["dataquery"])
# Add a simple root endpoint # Add a simple root endpoint
@app.get("/") @app.get("/")
async def root(): async def root():

View file

@ -7,10 +7,7 @@ from pydantic_settings import BaseSettings
class Settings(BaseSettings): class Settings(BaseSettings):
"""Application settings.""" """Application settings."""
# Preprocessor API key to access this app. # --- General Settings ---
PP_API_KEY: str = Field(
..., description="API key to access this app for preprocessing."
)
# Preprocessing configuration file path. # Preprocessing configuration file path.
PP_CONFIG_PATH: str = Field( PP_CONFIG_PATH: str = Field(
@ -24,6 +21,18 @@ class Settings(BaseSettings):
"data/database.sqlite", description="Path to the SQLite database." "data/database.sqlite", description="Path to the SQLite database."
) )
# --- API Keys ---
# Preprocessor API key to access this app.
PP_API_KEY: str = Field(
..., description="API key to access this app for preprocessing."
)
# API key needed to access the endpoint that queries the database.
DB_ENDPOINT_API_KEY: str = Field(
..., description="API key needed to access the database query endpoint."
)
# --- Power BI Settings --- # --- Power BI Settings ---
# Power BI base URL. # Power BI base URL.