fixed billing transactions mapping and added reporting
This commit is contained in:
parent
e4d41965f3
commit
34deb5f23d
6 changed files with 273 additions and 77 deletions
|
|
@ -162,6 +162,7 @@ class AiCallResponse(BaseModel):
|
|||
|
||||
content: str = Field(description="AI response content")
|
||||
modelName: str = Field(description="Selected model name")
|
||||
provider: str = Field(default="unknown", description="AI provider / connectorType (anthropic, openai, perplexity, etc.)")
|
||||
priceCHF: float = Field(default=0.0, description="Calculated price in USD")
|
||||
processingTime: float = Field(default=0.0, description="Duration in seconds")
|
||||
bytesSent: int = Field(default=0, description="Input data size in bytes")
|
||||
|
|
|
|||
|
|
@ -229,6 +229,7 @@ class AiObjects:
|
|||
return AiCallResponse(
|
||||
content=content,
|
||||
modelName=model.name,
|
||||
provider=model.connectorType,
|
||||
priceCHF=priceCHF,
|
||||
processingTime=processingTime,
|
||||
bytesSent=inputBytes,
|
||||
|
|
|
|||
|
|
@ -1174,7 +1174,7 @@ class BillingObjects:
|
|||
|
||||
def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all user-level transactions for specified mandates.
|
||||
Get all transactions for specified mandates (both USER and MANDATE accounts).
|
||||
|
||||
Args:
|
||||
mandateIds: Optional list of mandate IDs to filter. If None, returns all.
|
||||
|
|
@ -1190,9 +1190,8 @@ class BillingObjects:
|
|||
try:
|
||||
appInterface = getAppInterface(self.currentUser)
|
||||
|
||||
# Get all user accounts
|
||||
accountFilter = {"accountType": AccountTypeEnum.USER.value}
|
||||
allAccounts = self.db.getRecordset(BillingAccount, recordFilter=accountFilter)
|
||||
# Get ALL accounts (both USER and MANDATE types) to cover all billing models
|
||||
allAccounts = self.db.getRecordset(BillingAccount)
|
||||
|
||||
# Filter by mandate if specified
|
||||
if mandateIds:
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from modules.auth import limiter, requireSysAdmin, getRequestContext, RequestCon
|
|||
# Import billing components
|
||||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
|
||||
from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.routes.routeDataUsers import _applyFiltersAndSort
|
||||
from modules.datamodels.datamodelBilling import (
|
||||
BillingAccount,
|
||||
BillingTransaction,
|
||||
|
|
@ -779,22 +781,212 @@ async def getUserViewBalances(
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/view/users/transactions", response_model=List[UserTransactionResponse])
|
||||
class ViewStatisticsResponse(BaseModel):
|
||||
"""Aggregated statistics across all user's mandates."""
|
||||
totalCost: float = 0.0
|
||||
transactionCount: int = 0
|
||||
costByProvider: Dict[str, float] = {}
|
||||
costByFeature: Dict[str, float] = {}
|
||||
costByMandate: Dict[str, float] = {}
|
||||
timeSeries: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
@router.get("/view/statistics")
|
||||
@limiter.limit("30/minute")
|
||||
async def getUserViewStatistics(
|
||||
request: Request,
|
||||
period: str = Query(default="month", description="Period: 'day' or 'month'"),
|
||||
year: int = Query(default=None, description="Year"),
|
||||
month: Optional[int] = Query(None, description="Month (1-12, required for period='day')"),
|
||||
ctx: RequestContext = Depends(getRequestContext)
|
||||
) -> ViewStatisticsResponse:
|
||||
"""
|
||||
Get aggregated usage statistics across all user's mandates.
|
||||
- period='month': returns monthly time series for the given year
|
||||
- period='day': returns daily time series for the given month/year
|
||||
"""
|
||||
try:
|
||||
from datetime import timedelta
|
||||
|
||||
if year is None:
|
||||
year = datetime.now().year
|
||||
|
||||
if period == "day" and not month:
|
||||
month = datetime.now().month
|
||||
|
||||
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
||||
|
||||
# Get all mandates the user has access to
|
||||
if ctx.user.isSysAdmin:
|
||||
mandateIds = None
|
||||
else:
|
||||
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
|
||||
appInterface = getAppInterface(ctx.user)
|
||||
userMandates = appInterface.getUserMandates(ctx.user.id)
|
||||
mandateIds = []
|
||||
for um in userMandates:
|
||||
mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
|
||||
if mandateId:
|
||||
mandateIds.append(mandateId)
|
||||
if not mandateIds:
|
||||
logger.warning("No mandate IDs found for user")
|
||||
return ViewStatisticsResponse()
|
||||
|
||||
# Get all transactions
|
||||
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
|
||||
logger.info(f"View statistics: {len(allTransactions)} total transactions fetched for period={period}, year={year}, month={month}")
|
||||
|
||||
# Calculate date range
|
||||
if period == "day":
|
||||
startDate = date(year, month, 1)
|
||||
if month == 12:
|
||||
endDate = date(year + 1, 1, 1)
|
||||
else:
|
||||
endDate = date(year, month + 1, 1)
|
||||
else:
|
||||
startDate = date(year, 1, 1)
|
||||
endDate = date(year + 1, 1, 1)
|
||||
|
||||
# Filter by date range and only DEBIT transactions
|
||||
debits = []
|
||||
skippedNoDate = 0
|
||||
skippedDateRange = 0
|
||||
skippedNotDebit = 0
|
||||
|
||||
for t in allTransactions:
|
||||
createdAt = t.get("_createdAt")
|
||||
if not createdAt:
|
||||
skippedNoDate += 1
|
||||
continue
|
||||
|
||||
# Parse date from various formats (DB stores as DOUBLE PRECISION / Unix timestamp)
|
||||
txDate = None
|
||||
if isinstance(createdAt, (int, float)):
|
||||
txDate = datetime.fromtimestamp(createdAt).date()
|
||||
elif isinstance(createdAt, datetime):
|
||||
txDate = createdAt.date()
|
||||
elif isinstance(createdAt, date) and not isinstance(createdAt, datetime):
|
||||
txDate = createdAt
|
||||
elif isinstance(createdAt, str):
|
||||
try:
|
||||
# Try as float string first (Unix timestamp)
|
||||
txDate = datetime.fromtimestamp(float(createdAt)).date()
|
||||
except (ValueError, TypeError):
|
||||
try:
|
||||
txDate = datetime.fromisoformat(createdAt.replace("Z", "+00:00")).date()
|
||||
except (ValueError, TypeError):
|
||||
skippedNoDate += 1
|
||||
continue
|
||||
else:
|
||||
skippedNoDate += 1
|
||||
continue
|
||||
|
||||
if txDate < startDate or txDate >= endDate:
|
||||
skippedDateRange += 1
|
||||
continue
|
||||
|
||||
# Compare transactionType - handle both string and enum
|
||||
txType = t.get("transactionType")
|
||||
txTypeStr = str(txType) if txType is not None else ""
|
||||
if txTypeStr != "DEBIT" and txTypeStr != "TransactionTypeEnum.DEBIT":
|
||||
# Also check .value for enum objects
|
||||
txTypeValue = getattr(txType, 'value', txTypeStr)
|
||||
if txTypeValue != "DEBIT":
|
||||
skippedNotDebit += 1
|
||||
continue
|
||||
|
||||
t["_txDate"] = txDate
|
||||
debits.append(t)
|
||||
|
||||
logger.info(f"View statistics: {len(debits)} DEBIT transactions after filter. "
|
||||
f"Skipped: noDate={skippedNoDate}, dateRange={skippedDateRange}, notDebit={skippedNotDebit}")
|
||||
|
||||
# Aggregate totals
|
||||
totalCost = sum(t.get("amount", 0) for t in debits)
|
||||
|
||||
costByProvider: Dict[str, float] = {}
|
||||
costByFeature: Dict[str, float] = {}
|
||||
costByMandate: Dict[str, float] = {}
|
||||
|
||||
for t in debits:
|
||||
provider = t.get("aicoreProvider") or "unknown"
|
||||
costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0)
|
||||
|
||||
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
|
||||
featureCode = t.get("featureCode") or "unknown"
|
||||
featureKey = f"{mandate} / {featureCode}"
|
||||
costByFeature[featureKey] = costByFeature.get(featureKey, 0) + t.get("amount", 0)
|
||||
|
||||
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
|
||||
costByMandate[mandate] = costByMandate.get(mandate, 0) + t.get("amount", 0)
|
||||
|
||||
# Build time series (raw data only, no display logic)
|
||||
timeSeries = []
|
||||
if period == "day":
|
||||
numDays = (endDate - startDate).days
|
||||
for day in range(numDays):
|
||||
d = startDate + timedelta(days=day)
|
||||
dayCost = sum(t.get("amount", 0) for t in debits if t["_txDate"] == d)
|
||||
dayCount = sum(1 for t in debits if t["_txDate"] == d)
|
||||
if dayCost > 0 or dayCount > 0:
|
||||
timeSeries.append({
|
||||
"date": d.isoformat(),
|
||||
"cost": round(dayCost, 4),
|
||||
"count": dayCount
|
||||
})
|
||||
else:
|
||||
for m in range(1, 13):
|
||||
mStart = date(year, m, 1)
|
||||
mEnd = date(year, m + 1, 1) if m < 12 else date(year + 1, 1, 1)
|
||||
monthCost = sum(t.get("amount", 0) for t in debits if mStart <= t["_txDate"] < mEnd)
|
||||
monthCount = sum(1 for t in debits if mStart <= t["_txDate"] < mEnd)
|
||||
timeSeries.append({
|
||||
"date": f"{year}-{m:02d}",
|
||||
"cost": round(monthCost, 4),
|
||||
"count": monthCount
|
||||
})
|
||||
|
||||
return ViewStatisticsResponse(
|
||||
totalCost=round(totalCost, 4),
|
||||
transactionCount=len(debits),
|
||||
costByProvider=costByProvider,
|
||||
costByFeature=costByFeature,
|
||||
costByMandate=costByMandate,
|
||||
timeSeries=timeSeries
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting view statistics: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/view/users/transactions", response_model=PaginatedResponse[UserTransactionResponse])
|
||||
@limiter.limit("30/minute")
|
||||
async def getUserViewTransactions(
|
||||
request: Request,
|
||||
limit: int = Query(default=100, ge=1, le=1000),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
ctx: RequestContext = Depends(getRequestContext)
|
||||
):
|
||||
) -> PaginatedResponse[UserTransactionResponse]:
|
||||
"""
|
||||
Get user-level transactions.
|
||||
Get user-level transactions with pagination support.
|
||||
- SysAdmin: sees all user transactions across all mandates
|
||||
- MandateAdmin: sees user transactions for mandates they manage
|
||||
- Regular user: sees only their own transactions
|
||||
|
||||
Query Parameters:
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
"""
|
||||
try:
|
||||
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
||||
|
||||
# Parse pagination params
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
import json
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
|
||||
# Determine which mandates the user has access to
|
||||
if ctx.user.isSysAdmin:
|
||||
# SysAdmin sees all
|
||||
|
|
@ -812,34 +1004,78 @@ async def getUserViewTransactions(
|
|||
mandateIds.append(mandateId)
|
||||
|
||||
if not mandateIds:
|
||||
return []
|
||||
return PaginatedResponse(items=[], pagination=None)
|
||||
|
||||
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=limit)
|
||||
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
|
||||
|
||||
# Non-admin users only see their own transactions
|
||||
if not ctx.user.isSysAdmin:
|
||||
allTransactions = [t for t in allTransactions if t.get("userId") == ctx.user.id]
|
||||
logger.debug(f"Found {len(allTransactions)} transactions for mandates {mandateIds}")
|
||||
|
||||
result = []
|
||||
# Convert to response objects as dicts for filtering/sorting
|
||||
transactionDicts = []
|
||||
for t in allTransactions:
|
||||
result.append(UserTransactionResponse(
|
||||
id=t.get("id"),
|
||||
accountId=t.get("accountId"),
|
||||
transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
|
||||
amount=t.get("amount", 0.0),
|
||||
description=t.get("description", ""),
|
||||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
||||
workflowId=t.get("workflowId"),
|
||||
featureCode=t.get("featureCode"),
|
||||
aicoreProvider=t.get("aicoreProvider"),
|
||||
createdAt=t.get("_createdAt"),
|
||||
mandateId=t.get("mandateId"),
|
||||
mandateName=t.get("mandateName"),
|
||||
userId=t.get("userId"),
|
||||
userName=t.get("userName")
|
||||
))
|
||||
transactionDicts.append({
|
||||
"id": t.get("id"),
|
||||
"accountId": t.get("accountId"),
|
||||
"transactionType": t.get("transactionType", "DEBIT"),
|
||||
"amount": t.get("amount", 0.0),
|
||||
"description": t.get("description", ""),
|
||||
"referenceType": t.get("referenceType"),
|
||||
"workflowId": t.get("workflowId"),
|
||||
"featureCode": t.get("featureCode"),
|
||||
"aicoreProvider": t.get("aicoreProvider"),
|
||||
"createdAt": t.get("_createdAt"),
|
||||
"mandateId": t.get("mandateId"),
|
||||
"mandateName": t.get("mandateName"),
|
||||
"userId": t.get("userId"),
|
||||
"userName": t.get("userName"),
|
||||
})
|
||||
|
||||
return result
|
||||
# Apply filters and sorting
|
||||
filteredDicts = _applyFiltersAndSort(transactionDicts, paginationParams)
|
||||
|
||||
# Convert to response models
|
||||
def _toResponse(d):
|
||||
return UserTransactionResponse(
|
||||
id=d.get("id"),
|
||||
accountId=d.get("accountId"),
|
||||
transactionType=TransactionTypeEnum(d.get("transactionType", "DEBIT")),
|
||||
amount=d.get("amount", 0.0),
|
||||
description=d.get("description", ""),
|
||||
referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None,
|
||||
workflowId=d.get("workflowId"),
|
||||
featureCode=d.get("featureCode"),
|
||||
aicoreProvider=d.get("aicoreProvider"),
|
||||
createdAt=d.get("createdAt"),
|
||||
mandateId=d.get("mandateId"),
|
||||
mandateName=d.get("mandateName"),
|
||||
userId=d.get("userId"),
|
||||
userName=d.get("userName")
|
||||
)
|
||||
|
||||
if paginationParams:
|
||||
import math
|
||||
totalItems = len(filteredDicts)
|
||||
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
endIdx = startIdx + paginationParams.pageSize
|
||||
paginatedDicts = filteredDicts[startIdx:endIdx]
|
||||
|
||||
return PaginatedResponse(
|
||||
items=[_toResponse(d) for d in paginatedDicts],
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=[_toResponse(d) for d in filteredDicts],
|
||||
pagination=None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user view transactions: {e}")
|
||||
|
|
|
|||
|
|
@ -729,9 +729,8 @@ class ChatService:
|
|||
if not priceCHF or priceCHF <= 0:
|
||||
return
|
||||
|
||||
# Extract provider from model name (e.g., "anthropic.claude-3-sonnet" -> "anthropic")
|
||||
modelName = getattr(aiResponse, 'modelName', '') or ''
|
||||
aicoreProvider = modelName.split('.')[0] if '.' in modelName else 'unknown'
|
||||
# Get provider from AiCallResponse (set from model.connectorType)
|
||||
aicoreProvider = getattr(aiResponse, 'provider', None) or 'unknown'
|
||||
|
||||
# Get feature context if available
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
|
|
|
|||
|
|
@ -229,6 +229,7 @@ class ExtractionService:
|
|||
aiResponse = AiCallResponse(
|
||||
content="", # No content for extraction stats needed
|
||||
modelName=model.name,
|
||||
provider=model.connectorType,
|
||||
priceCHF=priceCHF,
|
||||
processingTime=processingTime,
|
||||
bytesSent=bytesSent,
|
||||
|
|
@ -1311,6 +1312,7 @@ class ExtractionService:
|
|||
return AiCallResponse(
|
||||
content=modelResponse.content,
|
||||
modelName=model.name,
|
||||
provider=model.connectorType,
|
||||
priceCHF=0.0,
|
||||
processingTime=processingTime,
|
||||
bytesSent=0,
|
||||
|
|
@ -1416,6 +1418,7 @@ class ExtractionService:
|
|||
return AiCallResponse(
|
||||
content=mergedContent,
|
||||
modelName=model.name,
|
||||
provider=model.connectorType,
|
||||
priceCHF=sum(r.priceCHF for r in chunkResults),
|
||||
processingTime=sum(r.processingTime for r in chunkResults),
|
||||
bytesSent=sum(r.bytesSent for r in chunkResults),
|
||||
|
|
@ -1428,49 +1431,6 @@ class ExtractionService:
|
|||
response = await aiObjects._callWithModel(model, prompt, contentPart.data, options)
|
||||
logger.info(f"✅ Content part processed successfully with model: {model.name}")
|
||||
return response
|
||||
chunks = await self.chunkContentPartForAi(contentPart, model, options, prompt)
|
||||
if not chunks:
|
||||
raise ValueError(f"Failed to chunk content part for model {model.name}")
|
||||
|
||||
logger.info(f"Starting to process {len(chunks)} chunks with model {model.name}")
|
||||
|
||||
if progressCallback:
|
||||
progressCallback(0.0, f"Starting to process {len(chunks)} chunks")
|
||||
|
||||
chunkResults = []
|
||||
for idx, chunk in enumerate(chunks):
|
||||
chunkNum = idx + 1
|
||||
chunkData = chunk.get('data', '')
|
||||
logger.info(f"Processing chunk {chunkNum}/{len(chunks)} with model {model.name}")
|
||||
|
||||
if progressCallback:
|
||||
progressCallback(chunkNum / len(chunks), f"Processing chunk {chunkNum}/{len(chunks)}")
|
||||
|
||||
try:
|
||||
chunkResponse = await aiObjects._callWithModel(model, prompt, chunkData, options)
|
||||
chunkResults.append(chunkResponse)
|
||||
logger.info(f"✅ Chunk {chunkNum}/{len(chunks)} processed successfully")
|
||||
|
||||
if progressCallback:
|
||||
progressCallback(chunkNum / len(chunks), f"Chunk {chunkNum}/{len(chunks)} processed")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error processing chunk {chunkNum}/{len(chunks)}: {str(e)}")
|
||||
raise
|
||||
|
||||
# Merge chunk results using unified mergePartResults
|
||||
# Pass original contentPart to preserve typeGroup for all chunks (one-to-many: 1 part -> N chunks)
|
||||
mergedContent = self.mergePartResults(chunkResults, options, [contentPart])
|
||||
|
||||
logger.info(f"✅ Content part chunked and processed with model: {model.name} ({len(chunks)} chunks)")
|
||||
return AiCallResponse(
|
||||
content=mergedContent,
|
||||
modelName=model.name,
|
||||
priceCHF=sum(r.priceCHF for r in chunkResults),
|
||||
processingTime=sum(r.processingTime for r in chunkResults),
|
||||
bytesSent=sum(r.bytesSent for r in chunkResults),
|
||||
bytesReceived=sum(r.bytesReceived for r in chunkResults),
|
||||
errorCount=sum(r.errorCount for r in chunkResults)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
lastError = e
|
||||
|
|
|
|||
Loading…
Reference in a new issue