# service-preprocessing/src/dataprocessor/domain/powerbi_reader.py
#
# Source-viewer metadata captured with this file: 259 lines, 8.6 KiB, Python.

from dataclasses import dataclass
from functools import partial

import anyio  # comes with FastAPI via Starlette/AnyIO
import httpx
import pandas as pd
from msal import ConfidentialClientApplication, SerializableTokenCache

from src.settings import settings
@dataclass
class PowerBIReader:
    """Fetch rows from a Power BI dataset table by running DAX queries.

    Queries are POSTed to ``{base_url}/datasets/{dataset_id}/executeQueries``
    with a bearer token, and the first result table of the response is
    returned as a pandas DataFrame.

    Attributes:
        dataset_id: Identifier of the dataset, interpolated into the URL.
        access_token: Bearer token sent in the ``Authorization`` header.
        table_name: Name of the table to evaluate.
        base_url: Root URL of the API (defaults to ``settings.POWERBI_BASE_URL``).
        include_nulls: Forwarded as ``serializerSettings.includeNulls``.
        measures: Measure names to add to the (non-batched) query. ``None``
            is normalised to an empty list.
        group_by_columns: Columns to group measures by in the (non-batched)
            query. ``None`` is normalised to an empty list.
        batch_size: Row count passed to ``TOPN`` for each batch.
        order_by_column: Column used as the keyset-pagination cursor. When it
            is unset, ``read_data`` issues a single un-batched query.
    """

    dataset_id: str
    access_token: str
    table_name: str
    base_url: str = settings.POWERBI_BASE_URL
    include_nulls: bool = True
    measures: list[str] | None = None
    group_by_columns: list[str] | None = None
    batch_size: int = 10000
    order_by_column: str | None = None

    def __post_init__(self) -> None:
        # Direct construction used to leave these as None while ``create``
        # produced lists; normalise so both paths yield the same state.
        if self.measures is None:
            self.measures = []
        if self.group_by_columns is None:
            self.group_by_columns = []

    @classmethod
    async def create(
        cls,
        *,
        dataset_id: str,
        access_token: str,
        table_name: str,
        measures: list[str] = None,
        group_by_columns: list[str] = None,
        batch_size: int = 10000,
        order_by_column: str | None = None,
        **kwargs,
    ) -> "PowerBIReader":
        """Build a reader from keyword arguments.

        This is an awaitable alternate constructor; it performs no I/O itself.

        Args:
            dataset_id: Identifier of the dataset to query.
            access_token: Bearer token for the API.
            table_name: Name of the table to evaluate.
            measures: Optional measure names; a falsy value becomes ``[]``.
            group_by_columns: Optional grouping columns; a falsy value
                becomes ``[]``.
            batch_size: Row count per batch when batching is enabled.
            order_by_column: Cursor column that enables batching.
            **kwargs: Remaining dataclass fields (e.g. ``base_url``,
                ``include_nulls``), passed through to the constructor.

        Returns:
            A configured ``PowerBIReader``.
        """
        return cls(
            dataset_id=dataset_id,
            access_token=access_token,
            table_name=table_name,
            measures=measures or [],
            group_by_columns=group_by_columns or [],
            batch_size=batch_size,
            order_by_column=order_by_column,
            **kwargs,
        )

    @staticmethod
    def _quote_table(name: str) -> str:
        """Return *name* as a single-quoted DAX table reference."""
        # DAX escapes a single quote inside a quoted table name by doubling it.
        return "'" + name.replace("'", "''") + "'"

    @staticmethod
    def _bracket(name: str) -> str:
        """Return *name* as a bracketed DAX column/measure reference."""
        # A closing bracket inside a bracketed name is escaped by doubling it.
        return "[" + name.replace("]", "]]") + "]"

    @staticmethod
    def _dax_string(value: str) -> str:
        """Return *value* as a double-quoted DAX string literal."""
        # A double quote inside a DAX string literal is escaped by doubling it.
        return '"' + value.replace('"', '""') + '"'

    @staticmethod
    def _strip_qualifier(col: str) -> str:
        """Turn a ``Table[Column]`` result header into plain ``Column``."""
        if "[" in col and col.endswith("]"):
            return col.split("[", 1)[1][:-1]
        return col

    def _measure_clauses(self) -> str:
        """Render the ``"Name", [Name]`` pairs for the configured measures."""
        return ", ".join(
            f"{self._dax_string(measure)}, {self._bracket(measure)}"
            for measure in self.measures
        )

    def _dax_query(self) -> str:
        """Generate the un-batched DAX query for the current configuration.

        Three shapes are produced:

        1. No measures: ``EVALUATE 'Table'``.
        2. Measures only:
           ``EVALUATE ADDCOLUMNS('Table', "M1", [M1], ...)``.
        3. Measures and group_by_columns:
           ``EVALUATE SUMMARIZECOLUMNS('Table'[Col1], ..., "M1", [M1], ...)``.

        Table, column and measure names are escaped per DAX quoting rules.

        Returns:
            DAX query string to execute against Power BI.
        """
        table = self._quote_table(self.table_name)

        # Case 1: no measures - plain table evaluation.
        if not self.measures:
            return f"EVALUATE {table}"

        measure_clauses = self._measure_clauses()

        # Case 2: measures without grouping.
        if not self.group_by_columns:
            return f"EVALUATE ADDCOLUMNS({table}, {measure_clauses})"

        # Case 3: measures aggregated over the grouping columns.
        group_cols = ", ".join(
            f"{table}{self._bracket(col)}" for col in self.group_by_columns
        )
        return f"EVALUATE SUMMARIZECOLUMNS({group_cols}, {measure_clauses})"

    def _dax_query_batch(self, last_value: str | int | None = None) -> str:
        """Generate a batched DAX query using TOPN and keyset pagination.

        The query takes the first ``batch_size`` rows by ``order_by_column``
        ascending, optionally restricted to rows whose cursor value is
        greater than ``last_value``. An explicit ``ORDER BY`` is appended
        because ``TOPN`` selects the top rows but does not promise to return
        them sorted, and the caller reads the cursor from the last row.

        The batched query evaluates the raw table only; ``measures`` and
        ``group_by_columns`` are not applied here.

        Args:
            last_value: The last value of order_by_column from the previous
                batch. None for the first batch.

        Returns:
            DAX query string for fetching the next batch.

        Raises:
            ValueError: If ``order_by_column`` is not configured.
        """
        if not self.order_by_column:
            raise ValueError("order_by_column must be set to build a batch query")

        table = self._quote_table(self.table_name)
        order_ref = f"{table}{self._bracket(self.order_by_column)}"

        if last_value is None:
            # First batch: no cursor yet, so read from the whole table.
            source = table
        else:
            # String cursors become quoted DAX literals; anything else is
            # rendered with str() (numeric cursors).
            if isinstance(last_value, str):
                literal = self._dax_string(last_value)
            else:
                literal = str(last_value)
            source = f"FILTER({table}, {order_ref} > {literal})"

        return (
            f"EVALUATE TOPN({self.batch_size}, {source}, {order_ref}, ASC) "
            f"ORDER BY {order_ref} ASC"
        )

    async def _execute_query(self, dax_query: str) -> pd.DataFrame:
        """Execute a DAX query and return the results as a DataFrame.

        Args:
            dax_query: The DAX query string to execute.

        Returns:
            DataFrame built from the rows of the first table of the first
            result, with ``Table[Column]`` headers reduced to ``Column``.
            An empty DataFrame when the query yields no rows.

        Raises:
            RuntimeError: If the HTTP status is not 200 or the response body
                does not have the expected results/tables/rows structure.
        """
        url = f"{self.base_url}/datasets/{self.dataset_id}/executeQueries"
        body = {
            "queries": [{"query": dax_query}],
            "serializerSettings": {"includeNulls": self.include_nulls},
        }
        headers = {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json",
        }

        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(url, headers=headers, json=body)
            if resp.status_code != 200:
                raise RuntimeError(
                    f"Power BI executeQueries failed: {resp.status_code} - {resp.text}"
                )
            payload = resp.json()

        try:
            rows = payload["results"][0]["tables"][0]["rows"]
        except (KeyError, IndexError) as e:
            raise RuntimeError("Unexpected executeQueries response structure") from e

        if not rows:
            return pd.DataFrame()

        df = pd.DataFrame(rows)
        # Columns often come back as "Table[Column]". Strip the qualifier.
        df.columns = [self._strip_qualifier(c) for c in df.columns]
        return df

    async def read_data(self) -> pd.DataFrame:
        """Fetch data from Power BI, using batching if order_by_column is set.

        Without ``order_by_column`` a single query built by ``_dax_query`` is
        executed (legacy behavior). With it, batches are fetched through
        ``_dax_query_batch`` until an empty batch is returned, feeding the
        cursor value of each batch's last row into the next query; the
        original author's intent is to stay under the Power BI API's
        per-query size limit.

        Returns:
            DataFrame containing all fetched data (empty if nothing was
            returned).

        Raises:
            RuntimeError: If a query fails, or if the cursor column of the
                last row in a batch is null, which makes it impossible to
                build the next ``>`` filter.
        """
        # Legacy behavior: no batching if order_by_column not set.
        if not self.order_by_column:
            return await self._execute_query(self._dax_query())

        batches: list[pd.DataFrame] = []
        last_value: str | int | None = None

        while True:
            df = await self._execute_query(self._dax_query_batch(last_value))
            if df.empty:
                # No more data to fetch.
                break
            batches.append(df)

            # Rows arrive sorted ascending, so the last row carries the
            # largest cursor value seen so far.
            new_last_value = df[self.order_by_column].iloc[-1]

            if pd.isna(new_last_value):
                # A null cursor cannot be used in the next "> value" filter;
                # fail loudly rather than stop early or send invalid DAX.
                raise RuntimeError(
                    f"Cannot paginate: column '{self.order_by_column}' is null "
                    "in the last row of a batch"
                )

            # Safety check: if the cursor did not advance we would loop forever.
            if new_last_value == last_value:
                break
            last_value = new_last_value

        if not batches:
            return pd.DataFrame()
        return pd.concat(batches, ignore_index=True)

    @staticmethod
    def _get_access_token_sync(
        tenant_id: str,
        client_id: str,
        client_secret: str,
        *,
        authority_base: str = settings.POWERBI_AUTHORITY_BASE,
        cache: SerializableTokenCache | None = None,
    ) -> str:
        """Acquire a Power BI access token with MSAL client credentials.

        Args:
            tenant_id: Tenant appended to ``authority_base`` to form the
                authority URL.
            client_id: Application (client) id.
            client_secret: Client secret used as the credential.
            authority_base: Base authority URL (defaults to
                ``settings.POWERBI_AUTHORITY_BASE``).
            cache: Optional MSAL token cache to reuse tokens between calls.

        Returns:
            The access token string.

        Raises:
            RuntimeError: If MSAL's result contains no ``access_token``.
        """
        scopes = ["https://analysis.windows.net/powerbi/api/.default"]
        authority = f"{authority_base}/{tenant_id}"
        app = ConfidentialClientApplication(
            client_id=client_id,
            authority=authority,
            client_credential=client_secret,
            token_cache=cache,  # pass a SerializableTokenCache to reuse tokens
        )
        # Try cache first; fall back to client credentials.
        result = app.acquire_token_silent(
            scopes, account=None
        ) or app.acquire_token_for_client(scopes=scopes)
        if "access_token" not in result:
            raise RuntimeError(
                f"MSAL token error: {result.get('error')} - {result.get('error_description')}"
            )
        return result["access_token"]

    @staticmethod
    async def _get_access_token_async(
        tenant_id: str,
        client_id: str,
        client_secret: str,
        **kwargs,
    ) -> str:
        """Acquire an access token without blocking the event loop.

        Args:
            tenant_id: Forwarded to ``_get_access_token_sync``.
            client_id: Forwarded to ``_get_access_token_sync``.
            client_secret: Forwarded to ``_get_access_token_sync``.
            **kwargs: Keyword-only options of ``_get_access_token_sync``
                (``authority_base``, ``cache``).

        Returns:
            The access token string.
        """
        # ``to_thread.run_sync`` only forwards positional arguments to the
        # target, so the keyword options are bound with ``partial`` first.
        token_call = partial(
            PowerBIReader._get_access_token_sync,
            tenant_id,
            client_id,
            client_secret,
            **kwargs,
        )
        # Offload the blocking MSAL HTTP call to a worker thread.
        return await anyio.to_thread.run_sync(token_call)