From 9bc7cd774c2b1558ca46389a63093b55d3b92bd1 Mon Sep 17 00:00:00 2001 From: Christopher Gondek Date: Fri, 26 Sep 2025 17:07:42 +0200 Subject: [PATCH] feat: add db query router --- src/dataprocessor/__init__.py | 1 + src/dataprocessor/domain/__init__.py | 1 + src/dataprocessor/domain/powerbi_reader.py | 22 +- src/dataprocessor/service.py | 2 +- src/dataquery/__init__.py | 1 + src/dataquery/router.py | 86 +++++++- src/dataquery/schemas.py | 68 ++++++ src/dataquery/service.py | 244 +++++++++++++++++++++ src/dependencies.py | 24 ++ src/main.py | 10 + src/settings.py | 17 +- 11 files changed, 449 insertions(+), 27 deletions(-) create mode 100644 src/dataprocessor/__init__.py create mode 100644 src/dataprocessor/domain/__init__.py create mode 100644 src/dataquery/__init__.py create mode 100644 src/dataquery/service.py diff --git a/src/dataprocessor/__init__.py b/src/dataprocessor/__init__.py new file mode 100644 index 0000000..4ede8e6 --- /dev/null +++ b/src/dataprocessor/__init__.py @@ -0,0 +1 @@ +# noqa diff --git a/src/dataprocessor/domain/__init__.py b/src/dataprocessor/domain/__init__.py new file mode 100644 index 0000000..4ede8e6 --- /dev/null +++ b/src/dataprocessor/domain/__init__.py @@ -0,0 +1 @@ +# noqa diff --git a/src/dataprocessor/domain/powerbi_reader.py b/src/dataprocessor/domain/powerbi_reader.py index 6dc65f3..9f02a10 100644 --- a/src/dataprocessor/domain/powerbi_reader.py +++ b/src/dataprocessor/domain/powerbi_reader.py @@ -1,11 +1,9 @@ from msal import ConfidentialClientApplication, SerializableTokenCache import anyio # comes with FastAPI via Starlette/AnyIO from dataclasses import dataclass -from settings import settings +from src.settings import settings import pandas as pd import httpx -import re -import sqlite3 @dataclass @@ -76,24 +74,6 @@ class PowerBIReader: df.columns = [_strip_qual(c) for c in df.columns] return df - async def to_sqlite( - self, db_path: str, table_name: str | None = None, if_exists: str = "replace" - ) -> int: - """ - Reads from Power BI and writes into SQLite. Returns row count written. - """ - df = await self.read_data() - if df.empty: - return 0 - - # Normalize SQLite table name - tn = table_name or self.table_name - tn = re.sub(r"[^\w]+", "_", tn).lower().strip("_") - - with sqlite3.connect(db_path) as conn: - df.to_sql(tn, conn, if_exists=if_exists, index=False) - return len(df) - @staticmethod def _get_access_token_sync( tenant_id: str, diff --git a/src/dataprocessor/service.py b/src/dataprocessor/service.py index 32f93c9..ac9082a 100644 --- a/src/dataprocessor/service.py +++ b/src/dataprocessor/service.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from src.settings import settings from src.dataprocessor.domain.powerbi_reader import PowerBIReader from src.dataprocessor.domain.preprocessor import Preprocessor -from dataprocessor.domain.base_datasaver import BaseDataSaver +from src.dataprocessor.domain.base_datasaver import BaseDataSaver from src.dataprocessor.domain.sqlite_datasaver import SQLiteDataSaver diff --git a/src/dataquery/__init__.py b/src/dataquery/__init__.py new file mode 100644 index 0000000..4ede8e6 --- /dev/null +++ b/src/dataquery/__init__.py @@ -0,0 +1 @@ +# noqa diff --git a/src/dataquery/router.py b/src/dataquery/router.py index e006d8d..ee927cd 100644 --- a/src/dataquery/router.py +++ b/src/dataquery/router.py @@ -1 +1,85 @@ -# TODO: Endpoint for SQL queries (read-only) +"""Router for data query endpoints.""" + +import logging +from fastapi import APIRouter, Depends, Path + +from src.dataquery.schemas import ( + DatabaseSchemaResponse, + SqlQueryRequest, + SqlQueryResponse, + TableSchemaResponse, +) +from src.dataquery.service import DataQueryService +from src.dependencies import require_db_api_key + +router = APIRouter() + +# Set up logging +logger = logging.getLogger(__name__) + + +@router.post("/query", response_model=SqlQueryResponse) +async def execute_sql_query( + request: SqlQueryRequest, + _: None = Depends(require_db_api_key), +) -> SqlQueryResponse: + """Execute a SELECT SQL query against the database. + + This endpoint only allows SELECT queries for security reasons. + All other SQL operations (INSERT, UPDATE, DELETE, etc.) are blocked. + + Args: + request: The SQL query request containing the query string. + + Returns: + SqlQueryResponse with query results or error information. + """ + service = await DataQueryService.create() + try: + result = await service.execute_query(query=request.query) + return result + finally: + await service.close() + + +@router.get("/schema", response_model=DatabaseSchemaResponse) +async def get_database_schema( + _: None = Depends(require_db_api_key), +) -> DatabaseSchemaResponse: + """Get information about all tables in the database. + + Returns a list of all tables with their names and row counts. + + Returns: + DatabaseSchemaResponse with table information. + """ + service = await DataQueryService.create() + try: + result = await service.get_database_schema() + return result + finally: + await service.close() + + +@router.get("/schema/{table_name}", response_model=TableSchemaResponse) +async def get_table_schema( + table_name: str = Path(..., description="Name of the table to inspect"), + _: None = Depends(require_db_api_key), +) -> TableSchemaResponse: + """Get detailed information about a specific table. + + Returns column information, data types, constraints, and sample data + for the specified table. + + Args: + table_name: Name of the table to inspect. + + Returns: + TableSchemaResponse with table structure and sample data. + """ + service = await DataQueryService.create() + try: + result = await service.get_table_schema(table_name=table_name) + return result + finally: + await service.close() diff --git a/src/dataquery/schemas.py b/src/dataquery/schemas.py index e69de29..ff52cea 100644 --- a/src/dataquery/schemas.py +++ b/src/dataquery/schemas.py @@ -0,0 +1,68 @@ +"""Schemas for the data query service.""" + +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field + + +class SqlQueryRequest(BaseModel): + """Request schema for SQL queries.""" + + query: str = Field( + ..., description="SQL query to execute (SELECT statements only)", min_length=1 + ) + + +class SqlQueryResponse(BaseModel): + """Response schema for SQL query results.""" + + success: bool = Field( + ..., description="Indicates if the query was executed successfully" + ) + data: List[Dict[str, Any]] = Field( + ..., description="Query results as list of dictionaries" + ) + columns: List[str] = Field(..., description="Column names in the result set") + row_count: int = Field(..., description="Number of rows returned") + message: Optional[str] = Field( + None, description="Additional information or error message" + ) + + +class TableInfo(BaseModel): + """Information about a database table.""" + + name: str = Field(..., description="Table name") + row_count: int = Field(..., description="Number of rows in the table") + + +class DatabaseSchemaResponse(BaseModel): + """Response schema for database schema information.""" + + success: bool = Field( + ..., description="Indicates if the schema retrieval was successful" + ) + tables: List[TableInfo] = Field(..., description="List of tables in the database") + table_count: int = Field(..., description="Total number of tables") + + +class ColumnInfo(BaseModel): + """Information about a table column.""" + + name: str = Field(..., description="Column name") + type: str = Field(..., description="Column data type") + nullable: bool = Field(..., description="Whether the column can be null") + primary_key: bool = Field(..., description="Whether the column is a primary key") + + +class TableSchemaResponse(BaseModel): + """Response schema for table structure information.""" + + success: bool = Field( + ..., description="Indicates if the table schema retrieval was successful" + ) + table_name: str = Field(..., description="Name of the table") + columns: List[ColumnInfo] = Field(..., description="List of columns in the table") + row_count: int = Field(..., description="Number of rows in the table") + sample_data: List[Dict[str, Any]] = Field( + ..., description="Sample rows from the table (up to 5 rows)" + ) diff --git a/src/dataquery/service.py b/src/dataquery/service.py new file mode 100644 index 0000000..c77513f --- /dev/null +++ b/src/dataquery/service.py @@ -0,0 +1,244 @@ +"""Service for handling database queries and schema operations.""" + +import re +import logging +from dataclasses import dataclass + +from sqlalchemy import text, inspect +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine +from sqlalchemy.exc import SQLAlchemyError + +from src.dataquery.schemas import ( + ColumnInfo, + DatabaseSchemaResponse, + SqlQueryResponse, + TableInfo, + TableSchemaResponse, +) +from src.settings import settings + +logger = logging.getLogger(__name__) + + +@dataclass +class DataQueryService: + """Service for executing database queries and retrieving schema information.""" + + engine: AsyncEngine + + @classmethod + async def create(cls) -> "DataQueryService": + """Create a new instance of DataQueryService.""" + db_url = f"sqlite+aiosqlite:///{settings.DB_PATH}" + engine = create_async_engine(db_url, future=True) + return cls(engine=engine) + + def _validate_select_query(self, *, query: str) -> bool: + """Validate that the query is a SELECT statement only. + + Args: + query: The SQL query to validate. + + Returns: + True if the query is valid (SELECT only), False otherwise. + """ + # Remove comments and normalize whitespace + cleaned_query = re.sub(r"--.*$", "", query, flags=re.MULTILINE) + cleaned_query = re.sub(r"/\*.*?\*/", "", cleaned_query, flags=re.DOTALL) + cleaned_query = re.sub(r"\s+", " ", cleaned_query.strip()) + + # Check if query starts with SELECT (case insensitive) + if not re.match(r"^\s*SELECT\s", cleaned_query, re.IGNORECASE): + return False + + # Check for dangerous keywords that shouldn't be in SELECT queries + dangerous_keywords = [ + r"\bINSERT\b", + r"\bUPDATE\b", + r"\bDELETE\b", + r"\bDROP\b", + r"\bCREATE\b", + r"\bALTER\b", + r"\bTRUNCATE\b", + r"\bREPLACE\b", + r"\bEXEC\b", + r"\bEXECUTE\b", + ] + + for keyword in dangerous_keywords: + if re.search(keyword, cleaned_query, re.IGNORECASE): + return False + + return True + + async def execute_query(self, *, query: str) -> SqlQueryResponse: + """Execute a SELECT query and return the results. + + Args: + query: The SQL query to execute. + + Returns: + SqlQueryResponse with query results or error information. + """ + # Validate query is SELECT only + if not self._validate_select_query(query=query): + return SqlQueryResponse( + success=False, + data=[], + columns=[], + row_count=0, + message="Only SELECT queries are allowed", + ) + + try: + async with self.engine.begin() as conn: + result = await conn.execute(text(query)) + rows = result.fetchall() + + # Convert rows to list of dictionaries + columns = list(result.keys()) if rows else [] + data = [dict(row._mapping) for row in rows] + + return SqlQueryResponse( + success=True, + data=data, + columns=columns, + row_count=len(data), + message=f"Query executed successfully. {len(data)} rows returned.", + ) + + except SQLAlchemyError as e: + logger.error(f"Database error executing query: {e}") + return SqlQueryResponse( + success=False, + data=[], + columns=[], + row_count=0, + message=f"Database error: {str(e)}", + ) + except Exception as e: + logger.error(f"Unexpected error executing query: {e}") + return SqlQueryResponse( + success=False, + data=[], + columns=[], + row_count=0, + message=f"Unexpected error: {str(e)}", + ) + + async def get_database_schema(self) -> DatabaseSchemaResponse: + """Get information about all tables in the database. + + Returns: + DatabaseSchemaResponse with table information. + """ + try: + async with self.engine.begin() as conn: + # Get table names using SQLAlchemy inspector + inspector = inspect(conn.sync_engine) + table_names = inspector.get_table_names() + + tables = [] + for table_name in table_names: + # Get row count for each table + result = await conn.execute( + text(f"SELECT COUNT(*) as count FROM {table_name}") + ) + row_count = result.scalar() + + tables.append(TableInfo(name=table_name, row_count=row_count or 0)) + + return DatabaseSchemaResponse( + success=True, tables=tables, table_count=len(tables) + ) + + except SQLAlchemyError as e: + logger.error(f"Database error getting schema: {e}") + return DatabaseSchemaResponse(success=False, tables=[], table_count=0) + except Exception as e: + logger.error(f"Unexpected error getting schema: {e}") + return DatabaseSchemaResponse(success=False, tables=[], table_count=0) + + async def get_table_schema(self, *, table_name: str) -> TableSchemaResponse: + """Get detailed information about a specific table. + + Args: + table_name: Name of the table to inspect. + + Returns: + TableSchemaResponse with table structure and sample data. + """ + try: + async with self.engine.begin() as conn: + # Check if table exists + inspector = inspect(conn.sync_engine) + table_names = inspector.get_table_names() + + if table_name not in table_names: + return TableSchemaResponse( + success=False, + table_name=table_name, + columns=[], + row_count=0, + sample_data=[], + ) + + # Get column information + columns_info = inspector.get_columns(table_name) + pk_constraint = inspector.get_pk_constraint(table_name) + primary_keys = pk_constraint.get("constrained_columns", []) + + columns = [] + for col_info in columns_info: + columns.append( + ColumnInfo( + name=col_info["name"], + type=str(col_info["type"]), + nullable=col_info["nullable"], + primary_key=col_info["name"] in primary_keys, + ) + ) + + # Get row count + result = await conn.execute( + text(f"SELECT COUNT(*) as count FROM {table_name}") + ) + row_count = result.scalar() or 0 + + # Get sample data (up to 5 rows) + sample_result = await conn.execute( + text(f"SELECT * FROM {table_name} LIMIT 5") + ) + sample_rows = sample_result.fetchall() + sample_data = [dict(row._mapping) for row in sample_rows] + + return TableSchemaResponse( + success=True, + table_name=table_name, + columns=columns, + row_count=row_count, + sample_data=sample_data, + ) + + except SQLAlchemyError as e: + logger.error(f"Database error getting table schema for {table_name}: {e}") + return TableSchemaResponse( + success=False, + table_name=table_name, + columns=[], + row_count=0, + sample_data=[], + ) + except Exception as e: + logger.error(f"Unexpected error getting table schema for {table_name}: {e}") + return TableSchemaResponse( + success=False, + table_name=table_name, + columns=[], + row_count=0, + sample_data=[], + ) + + async def close(self) -> None: + """Close the database engine.""" + await self.engine.dispose() diff --git a/src/dependencies.py b/src/dependencies.py index b10b9ed..62305d1 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -30,3 +30,27 @@ def require_pp_api_key(*, api_key: str = api_key_header) -> None: detail="Invalid or missing API key", headers={"WWW-Authenticate": "ApiKey"}, ) + + +# API key header for database endpoint authentication +db_api_key_header: SecurityBase = APIKeyHeader( + name="X-DB-API-Key", + description="API key for database query access", +) + + +def require_db_api_key(*, api_key: str = db_api_key_header) -> None: + """Validate the database API key. + + Args: + api_key: The API key from the X-DB-API-Key header. + + Raises: + HTTPException: If the API key is invalid or missing. + """ + if api_key != settings.DB_ENDPOINT_API_KEY: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing database API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) diff --git a/src/main.py b/src/main.py index 3c60c73..41f6a08 100644 --- a/src/main.py +++ b/src/main.py @@ -4,6 +4,9 @@ import logging from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from src.dataprocessor.router import router as dataprocessor_router +from src.dataquery.router import router as dataquery_router + # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -26,6 +29,13 @@ app.add_middleware( ) +# Include routers +app.include_router( + dataprocessor_router, prefix="/api/v1/dataprocessor", tags=["dataprocessor"] +) +app.include_router(dataquery_router, prefix="/api/v1/dataquery", tags=["dataquery"]) + + # Add a simple root endpoint @app.get("/") async def root(): diff --git a/src/settings.py b/src/settings.py index 9554ca6..5667ba8 100644 --- a/src/settings.py +++ b/src/settings.py @@ -7,10 +7,7 @@ from pydantic_settings import BaseSettings class Settings(BaseSettings): """Application settings.""" - # Preprocessor API key to access this app. - PP_API_KEY: str = Field( - ..., description="API key to access this app for preprocessing." - ) + # --- General Settings --- # Preprocessing configuration file path. PP_CONFIG_PATH: str = Field( @@ -24,6 +21,18 @@ class Settings(BaseSettings): "data/database.sqlite", description="Path to the SQLite database." ) + # --- API Keys --- + + # Preprocessor API key to access this app. + PP_API_KEY: str = Field( + ..., description="API key to access this app for preprocessing." + ) + + # API key needed to access the endpoint that queries the database. + DB_ENDPOINT_API_KEY: str = Field( + ..., description="API key needed to access the database query endpoint." + ) + # --- Power BI Settings --- # Power BI base URL.