chore: enforce query row limit
This commit is contained in:
parent
e7dd3ea999
commit
dde61f447d
3 changed files with 56 additions and 0 deletions
|
|
@ -71,6 +71,50 @@ class DataQueryService:
|
|||
|
||||
return True
|
||||
|
||||
def _enforce_query_limit(self, *, query: str) -> str:
|
||||
"""Enforce maximum row limit on query.
|
||||
|
||||
If query has LIMIT > SQL_ROW_LIMIT, replace with SQL_ROW_LIMIT.
|
||||
If query has no LIMIT, append LIMIT SQL_ROW_LIMIT.
|
||||
If query has LIMIT <= SQL_ROW_LIMIT, keep as is.
|
||||
|
||||
Args:
|
||||
query: The SQL query to enforce limit on.
|
||||
|
||||
Returns:
|
||||
Query with enforced limit.
|
||||
"""
|
||||
max_limit = settings.SQL_ROW_LIMIT
|
||||
|
||||
# Remove comments and normalize whitespace for parsing
|
||||
cleaned_query = re.sub(r"--.*$", "", query, flags=re.MULTILINE)
|
||||
cleaned_query = re.sub(r"/\*.*?\*/", "", cleaned_query, flags=re.DOTALL)
|
||||
|
||||
# Look for LIMIT clause (case insensitive)
|
||||
# Pattern matches: LIMIT <number> or LIMIT <number> OFFSET <number>
|
||||
limit_pattern = r"\bLIMIT\s+(\d+)(\s+OFFSET\s+\d+)?\s*$"
|
||||
match = re.search(limit_pattern, cleaned_query, re.IGNORECASE)
|
||||
|
||||
if match:
|
||||
# Extract the current limit value
|
||||
current_limit = int(match.group(1))
|
||||
|
||||
if current_limit > max_limit:
|
||||
# Replace with max_limit while preserving OFFSET if present
|
||||
offset_clause = match.group(2) or ""
|
||||
# Find the position in the original query to replace
|
||||
# Use the original query to preserve formatting
|
||||
original_match = re.search(limit_pattern, query, re.IGNORECASE)
|
||||
if original_match:
|
||||
new_limit_clause = f"LIMIT {max_limit}{offset_clause}"
|
||||
query = query[: original_match.start()] + new_limit_clause
|
||||
# If current_limit <= max_limit, keep query as is
|
||||
else:
|
||||
# No LIMIT clause found, append one
|
||||
query = f"{query.rstrip()} LIMIT {max_limit}"
|
||||
|
||||
return query
|
||||
|
||||
async def execute_query(self, *, query: str) -> SqlQueryResponse:
|
||||
"""Execute a SELECT query and return the results.
|
||||
|
||||
|
|
@ -90,6 +134,9 @@ class DataQueryService:
|
|||
message="Only SELECT queries are allowed",
|
||||
)
|
||||
|
||||
# Enforce row limit on the query
|
||||
query = self._enforce_query_limit(query=query)
|
||||
|
||||
try:
|
||||
async with self.engine.begin() as conn:
|
||||
result = await conn.execute(text(query))
|
||||
|
|
|
|||
|
|
@ -30,6 +30,14 @@ class Settings(BaseSettings):
|
|||
description="Path to the SQLite database.",
|
||||
)
|
||||
|
||||
# --- Database Query Settings ---
|
||||
|
||||
# Maximum number of rows to return from SQL queries.
|
||||
SQL_ROW_LIMIT: int = Field(
|
||||
default=50,
|
||||
description="Maximum number of rows to return from SQL queries. Defaults to 50.",
|
||||
)
|
||||
|
||||
# --- API Keys ---
|
||||
|
||||
# Preprocessor API key to access this app.
|
||||
|
|
|
|||
1
tests/dataquery/__init__.py
Normal file
1
tests/dataquery/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Data query tests."""
|
||||
Loading…
Reference in a new issue