259 lines
8.6 KiB
Python
259 lines
8.6 KiB
Python
import functools
from dataclasses import dataclass, field

import anyio  # comes with FastAPI via Starlette/AnyIO
import httpx
import pandas as pd
from msal import ConfidentialClientApplication, SerializableTokenCache

from src.settings import settings


@dataclass
class PowerBIReader:
    """Read a Power BI dataset table (optionally with measures) into a DataFrame.

    Queries are sent as DAX to the Power BI REST ``executeQueries`` endpoint.
    When ``order_by_column`` is set, rows are fetched in batches of
    ``batch_size`` using keyset pagination on that column; otherwise a single
    query returns everything.

    Attributes:
        dataset_id: Dataset ID placed in the request URL.
        access_token: Bearer token sent in the ``Authorization`` header.
            Excluded from ``repr`` so it does not leak into logs.
        table_name: Table to query (quoted and escaped for DAX).
        base_url: REST API base URL. Defaults to ``settings.POWERBI_BASE_URL``,
            which is read once, when the class is defined.
        include_nulls: Forwarded as ``serializerSettings.includeNulls``.
        measures: Measure names to add to the result; ``None`` or ``[]`` means
            no measures.
        group_by_columns: Columns to group measures by; ``None`` or ``[]``
            means no grouping.
        batch_size: ``N`` passed to DAX ``TOPN`` for each batch.
        order_by_column: Column used for keyset pagination; ``None`` disables
            batching. With ``group_by_columns`` it should be one of the
            grouped columns so it appears in the SUMMARIZECOLUMNS result.
        request_timeout: Timeout in seconds for each HTTP request.
    """

    dataset_id: str
    access_token: str = field(repr=False)
    table_name: str
    base_url: str = settings.POWERBI_BASE_URL
    include_nulls: bool = True
    measures: list[str] | None = None
    group_by_columns: list[str] | None = None
    batch_size: int = 10000
    order_by_column: str | None = None
    request_timeout: float = 60.0

    @classmethod
    async def create(
        cls,
        *,
        dataset_id: str,
        access_token: str,
        table_name: str,
        measures: list[str] | None = None,
        group_by_columns: list[str] | None = None,
        batch_size: int = 10000,
        order_by_column: str | None = None,
        **kwargs,
    ) -> "PowerBIReader":
        """Build a reader, normalising ``None`` measures/grouping to empty lists.

        Performs no I/O itself. Extra ``kwargs`` (e.g. ``base_url``,
        ``include_nulls``, ``request_timeout``) are forwarded to the dataclass
        constructor.
        """
        return cls(
            dataset_id=dataset_id,
            access_token=access_token,
            table_name=table_name,
            measures=measures or [],
            group_by_columns=group_by_columns or [],
            batch_size=batch_size,
            order_by_column=order_by_column,
            **kwargs,
        )

    # ------------------------------------------------------------------ #
    # DAX building helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _quote_table(name: str) -> str:
        """Return *name* as a single-quoted DAX table reference (``'`` doubled)."""
        return "'" + name.replace("'", "''") + "'"

    @staticmethod
    def _bracket(name: str) -> str:
        """Return *name* as a bracketed DAX column/measure reference (``]`` doubled)."""
        return "[" + name.replace("]", "]]") + "]"

    @staticmethod
    def _dax_string(value: str) -> str:
        """Return *value* as a double-quoted DAX string literal (``"`` doubled)."""
        return '"' + value.replace('"', '""') + '"'

    def _measure_clauses(self) -> str:
        """Return ``"Name", [Name]`` pairs for every configured measure."""
        return ", ".join(
            f"{self._dax_string(m)}, {self._bracket(m)}" for m in self.measures or []
        )

    def _table_expression(self) -> str:
        """Return the DAX table expression implied by the configuration.

        1. No measures: ``'TableName'`` -- all columns of the table.
        2. Measures only: ``ADDCOLUMNS('TableName', "M1", [M1], ...)`` -- all
           columns plus the measures.
        3. Measures + group_by_columns:
           ``SUMMARIZECOLUMNS('Table'[Col1], ..., "M1", [M1], ...)`` --
           measures aggregated by the grouped columns.
        """
        table = self._quote_table(self.table_name)

        if not self.measures:
            return table

        measure_clauses = self._measure_clauses()

        if not self.group_by_columns:
            return f"ADDCOLUMNS({table}, {measure_clauses})"

        group_cols = ", ".join(
            f"{table}{self._bracket(col)}" for col in self.group_by_columns
        )
        return f"SUMMARIZECOLUMNS({group_cols}, {measure_clauses})"

    def _dax_query(self) -> str:
        """Generate the single (non-batched) DAX query for this configuration.

        Returns:
            ``EVALUATE`` followed by the expression from ``_table_expression``
            (plain table, ADDCOLUMNS, or SUMMARIZECOLUMNS).
        """
        return f"EVALUATE {self._table_expression()}"

    def _dax_query_batch(self, last_value: str | int | float | None = None) -> str:
        """Generate one batched DAX query using TOPN and keyset pagination.

        The same table expression as ``_dax_query`` is used, so configured
        measures and grouping are kept in batched mode. Rows whose
        ``order_by_column`` is greater than ``last_value`` are kept with FILTER.
        ``TOPN`` takes the first ``batch_size`` of them. An explicit
        ``ORDER BY`` sorts the returned rows, because DAX ``TOPN`` itself does
        not guarantee result order and the caller reads the last row as the
        next keyset. Per the DAX docs, ``TOPN`` also returns every row tied at
        the N-th position, so a strict ``>`` does not skip duplicate keys at a
        batch boundary.

        Args:
            last_value: Last ``order_by_column`` value of the previous batch,
                or ``None`` for the first batch. ``str`` values become escaped
                DAX string literals; anything else is rendered with ``str()``.
                NOTE(review): date/time keys come back from the API as
                strings and would be compared as text -- confirm DAX accepts
                this, or paginate on a numeric key.

        Returns:
            DAX query string for fetching the next batch.

        Raises:
            ValueError: If ``order_by_column`` is not set.
        """
        if not self.order_by_column:
            raise ValueError("order_by_column must be set for batched queries")

        table = self._quote_table(self.table_name)
        order_ref = f"{table}{self._bracket(self.order_by_column)}"
        source = self._table_expression()

        if last_value is not None:
            if isinstance(last_value, str):
                literal = self._dax_string(last_value)
            else:
                literal = str(last_value)
            source = f"FILTER({source}, {order_ref} > {literal})"

        return (
            f"EVALUATE TOPN({self.batch_size}, {source}, {order_ref}, ASC) "
            f"ORDER BY {order_ref} ASC"
        )

    # ------------------------------------------------------------------ #
    # Query execution
    # ------------------------------------------------------------------ #

    @staticmethod
    def _strip_qualifier(col: str) -> str:
        """Turn ``Table[Column]`` (or ``[Measure]``) into ``Column``; else unchanged."""
        if "[" in col and col.endswith("]"):
            return col.split("[", 1)[1][:-1]
        return col

    async def _execute_query(
        self,
        dax_query: str,
        client: httpx.AsyncClient | None = None,
    ) -> pd.DataFrame:
        """Execute a DAX query and return the results as a DataFrame.

        Args:
            dax_query: The DAX query string to execute.
            client: Optional open ``httpx.AsyncClient`` to reuse (e.g. across
                batches). When omitted, a client with ``request_timeout`` is
                created and closed for this single call.

        Returns:
            DataFrame built from the first table of the first query result,
            with ``Table[Column]`` headers reduced to ``Column``. An empty
            DataFrame when there are no rows.

        Raises:
            RuntimeError: On a non-200 response, a query-level ``error`` entry,
                or a body that does not have the expected JSON shape.
        """
        url = f"{self.base_url}/datasets/{self.dataset_id}/executeQueries"
        body = {
            "queries": [{"query": dax_query}],
            "serializerSettings": {"includeNulls": self.include_nulls},
        }
        headers = {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json",
        }

        if client is None:
            async with httpx.AsyncClient(timeout=self.request_timeout) as own_client:
                resp = await own_client.post(url, headers=headers, json=body)
        else:
            resp = await client.post(url, headers=headers, json=body)

        if resp.status_code != 200:
            raise RuntimeError(
                f"Power BI executeQueries failed: {resp.status_code} - {resp.text}"
            )

        try:
            query_result = resp.json()["results"][0]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # ValueError covers a body that is not valid JSON.
            raise RuntimeError("Unexpected executeQueries response structure") from e

        # Surface a query-level error entry rather than reporting a generic
        # structure error because "tables" is missing.
        if isinstance(query_result, dict) and "error" in query_result:
            raise RuntimeError(f"Power BI query error: {query_result['error']}")

        try:
            rows = query_result["tables"][0]["rows"]
        except (KeyError, IndexError, TypeError) as e:
            raise RuntimeError("Unexpected executeQueries response structure") from e

        if not rows:
            return pd.DataFrame()

        df = pd.DataFrame(rows)
        # Columns often come back as "Table[Column]"; keep only "Column".
        df.columns = [self._strip_qualifier(c) for c in df.columns]
        return df

    async def read_data(self) -> pd.DataFrame:
        """Fetch data from Power BI, batching when ``order_by_column`` is set.

        Without ``order_by_column``, a single query fetches everything (legacy
        behavior). With it, rows are fetched in ``batch_size`` chunks using
        keyset pagination, to stay under the Power BI API's per-query value
        limit. One HTTP client is reused for every batch. Fetching stops when
        a batch is empty or when the last key does not advance.

        Returns:
            DataFrame containing all fetched rows (empty if there were none).
        """
        if not self.order_by_column:
            return await self._execute_query(self._dax_query())

        frames: list[pd.DataFrame] = []
        last_value: str | int | float | None = None

        async with httpx.AsyncClient(timeout=self.request_timeout) as client:
            while True:
                df = await self._execute_query(
                    self._dax_query_batch(last_value), client
                )
                if df.empty:
                    break

                frames.append(df)

                # The query sorts ascending by the key, so the last row holds
                # the largest key seen so far.
                new_last_value = df[self.order_by_column].iloc[-1]

                # Safety net: a key that does not advance would loop forever.
                if new_last_value == last_value:
                    break
                last_value = new_last_value

        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)

    # ------------------------------------------------------------------ #
    # Authentication
    # ------------------------------------------------------------------ #

    @staticmethod
    def _get_access_token_sync(
        tenant_id: str,
        client_id: str,
        client_secret: str,
        *,
        authority_base: str = settings.POWERBI_AUTHORITY_BASE,
        cache: SerializableTokenCache | None = None,
    ) -> str:
        """Acquire an app-only Power BI access token with MSAL (blocking).

        Tries ``acquire_token_silent`` first, then falls back to the
        client-credentials flow.

        Args:
            tenant_id: Tenant appended to ``authority_base`` to form the authority.
            client_id: Application (client) ID.
            client_secret: Client secret used as the credential.
            authority_base: Authority URL prefix; defaults to
                ``settings.POWERBI_AUTHORITY_BASE``.
            cache: Optional ``SerializableTokenCache``. Pass one to reuse
                tokens across calls; without it, each call builds a fresh
                application with an empty cache.

        Returns:
            The access token string.

        Raises:
            RuntimeError: If MSAL's result has no ``access_token``.
        """
        scopes = ["https://analysis.windows.net/powerbi/api/.default"]

        app = ConfidentialClientApplication(
            client_id=client_id,
            authority=f"{authority_base}/{tenant_id}",
            client_credential=client_secret,
            token_cache=cache,
        )

        # Try cache first; fall back to client credentials
        result = app.acquire_token_silent(
            scopes, account=None
        ) or app.acquire_token_for_client(scopes=scopes)

        if "access_token" not in result:
            raise RuntimeError(
                f"MSAL token error: {result.get('error')} - {result.get('error_description')}"
            )
        return result["access_token"]

    @staticmethod
    async def _get_access_token_async(
        tenant_id: str,
        client_id: str,
        client_secret: str,
        **kwargs,
    ) -> str:
        """Async wrapper around ``_get_access_token_sync``.

        The blocking MSAL call runs in a worker thread via AnyIO. Keyword
        arguments (``authority_base``, ``cache``) are forwarded to the sync
        function.
        """
        # anyio.to_thread.run_sync forwards only positional arguments to the
        # callable (its own keywords are options such as `limiter`), so the
        # keyword arguments are bound up front with functools.partial.
        fetch_token = functools.partial(
            PowerBIReader._get_access_token_sync,
            tenant_id,
            client_id,
            client_secret,
            **kwargs,
        )
        return await anyio.to_thread.run_sync(fetch_token)