fixed billing transactions mapping and added reporting

This commit is contained in:
patrick-motsch 2026-02-08 13:15:19 +01:00
parent e4d41965f3
commit 34deb5f23d
6 changed files with 273 additions and 77 deletions

View file

@ -162,6 +162,7 @@ class AiCallResponse(BaseModel):
content: str = Field(description="AI response content") content: str = Field(description="AI response content")
modelName: str = Field(description="Selected model name") modelName: str = Field(description="Selected model name")
provider: str = Field(default="unknown", description="AI provider / connectorType (anthropic, openai, perplexity, etc.)")
priceCHF: float = Field(default=0.0, description="Calculated price in USD") priceCHF: float = Field(default=0.0, description="Calculated price in USD")
processingTime: float = Field(default=0.0, description="Duration in seconds") processingTime: float = Field(default=0.0, description="Duration in seconds")
bytesSent: int = Field(default=0, description="Input data size in bytes") bytesSent: int = Field(default=0, description="Input data size in bytes")

View file

@ -229,6 +229,7 @@ class AiObjects:
return AiCallResponse( return AiCallResponse(
content=content, content=content,
modelName=model.name, modelName=model.name,
provider=model.connectorType,
priceCHF=priceCHF, priceCHF=priceCHF,
processingTime=processingTime, processingTime=processingTime,
bytesSent=inputBytes, bytesSent=inputBytes,

View file

@ -1174,7 +1174,7 @@ class BillingObjects:
def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
""" """
Get all user-level transactions for specified mandates. Get all transactions for specified mandates (both USER and MANDATE accounts).
Args: Args:
mandateIds: Optional list of mandate IDs to filter. If None, returns all. mandateIds: Optional list of mandate IDs to filter. If None, returns all.
@ -1190,9 +1190,8 @@ class BillingObjects:
try: try:
appInterface = getAppInterface(self.currentUser) appInterface = getAppInterface(self.currentUser)
# Get all user accounts # Get ALL accounts (both USER and MANDATE types) to cover all billing models
accountFilter = {"accountType": AccountTypeEnum.USER.value} allAccounts = self.db.getRecordset(BillingAccount)
allAccounts = self.db.getRecordset(BillingAccount, recordFilter=accountFilter)
# Filter by mandate if specified # Filter by mandate if specified
if mandateIds: if mandateIds:

View file

@ -22,6 +22,8 @@ from modules.auth import limiter, requireSysAdmin, getRequestContext, RequestCon
# Import billing components # Import billing components
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.routes.routeDataUsers import _applyFiltersAndSort
from modules.datamodels.datamodelBilling import ( from modules.datamodels.datamodelBilling import (
BillingAccount, BillingAccount,
BillingTransaction, BillingTransaction,
@ -779,22 +781,212 @@ async def getUserViewBalances(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/view/users/transactions", response_model=List[UserTransactionResponse]) class ViewStatisticsResponse(BaseModel):
"""Aggregated statistics across all user's mandates."""
totalCost: float = 0.0
transactionCount: int = 0
costByProvider: Dict[str, float] = {}
costByFeature: Dict[str, float] = {}
costByMandate: Dict[str, float] = {}
timeSeries: List[Dict[str, Any]] = []
@router.get("/view/statistics")
@limiter.limit("30/minute")
async def getUserViewStatistics(
request: Request,
period: str = Query(default="month", description="Period: 'day' or 'month'"),
year: int = Query(default=None, description="Year"),
month: Optional[int] = Query(None, description="Month (1-12, required for period='day')"),
ctx: RequestContext = Depends(getRequestContext)
) -> ViewStatisticsResponse:
"""
Get aggregated usage statistics across all user's mandates.
- period='month': returns monthly time series for the given year
- period='day': returns daily time series for the given month/year
"""
try:
from datetime import timedelta
if year is None:
year = datetime.now().year
if period == "day" and not month:
month = datetime.now().month
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
# Get all mandates the user has access to
if ctx.user.isSysAdmin:
mandateIds = None
else:
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
appInterface = getAppInterface(ctx.user)
userMandates = appInterface.getUserMandates(ctx.user.id)
mandateIds = []
for um in userMandates:
mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
if mandateId:
mandateIds.append(mandateId)
if not mandateIds:
logger.warning("No mandate IDs found for user")
return ViewStatisticsResponse()
# Get all transactions
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
logger.info(f"View statistics: {len(allTransactions)} total transactions fetched for period={period}, year={year}, month={month}")
# Calculate date range
if period == "day":
startDate = date(year, month, 1)
if month == 12:
endDate = date(year + 1, 1, 1)
else:
endDate = date(year, month + 1, 1)
else:
startDate = date(year, 1, 1)
endDate = date(year + 1, 1, 1)
# Filter by date range and only DEBIT transactions
debits = []
skippedNoDate = 0
skippedDateRange = 0
skippedNotDebit = 0
for t in allTransactions:
createdAt = t.get("_createdAt")
if not createdAt:
skippedNoDate += 1
continue
# Parse date from various formats (DB stores as DOUBLE PRECISION / Unix timestamp)
txDate = None
if isinstance(createdAt, (int, float)):
txDate = datetime.fromtimestamp(createdAt).date()
elif isinstance(createdAt, datetime):
txDate = createdAt.date()
elif isinstance(createdAt, date) and not isinstance(createdAt, datetime):
txDate = createdAt
elif isinstance(createdAt, str):
try:
# Try as float string first (Unix timestamp)
txDate = datetime.fromtimestamp(float(createdAt)).date()
except (ValueError, TypeError):
try:
txDate = datetime.fromisoformat(createdAt.replace("Z", "+00:00")).date()
except (ValueError, TypeError):
skippedNoDate += 1
continue
else:
skippedNoDate += 1
continue
if txDate < startDate or txDate >= endDate:
skippedDateRange += 1
continue
# Compare transactionType - handle both string and enum
txType = t.get("transactionType")
txTypeStr = str(txType) if txType is not None else ""
if txTypeStr != "DEBIT" and txTypeStr != "TransactionTypeEnum.DEBIT":
# Also check .value for enum objects
txTypeValue = getattr(txType, 'value', txTypeStr)
if txTypeValue != "DEBIT":
skippedNotDebit += 1
continue
t["_txDate"] = txDate
debits.append(t)
logger.info(f"View statistics: {len(debits)} DEBIT transactions after filter. "
f"Skipped: noDate={skippedNoDate}, dateRange={skippedDateRange}, notDebit={skippedNotDebit}")
# Aggregate totals
totalCost = sum(t.get("amount", 0) for t in debits)
costByProvider: Dict[str, float] = {}
costByFeature: Dict[str, float] = {}
costByMandate: Dict[str, float] = {}
for t in debits:
provider = t.get("aicoreProvider") or "unknown"
costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0)
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
featureCode = t.get("featureCode") or "unknown"
featureKey = f"{mandate} / {featureCode}"
costByFeature[featureKey] = costByFeature.get(featureKey, 0) + t.get("amount", 0)
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
costByMandate[mandate] = costByMandate.get(mandate, 0) + t.get("amount", 0)
# Build time series (raw data only, no display logic)
timeSeries = []
if period == "day":
numDays = (endDate - startDate).days
for day in range(numDays):
d = startDate + timedelta(days=day)
dayCost = sum(t.get("amount", 0) for t in debits if t["_txDate"] == d)
dayCount = sum(1 for t in debits if t["_txDate"] == d)
if dayCost > 0 or dayCount > 0:
timeSeries.append({
"date": d.isoformat(),
"cost": round(dayCost, 4),
"count": dayCount
})
else:
for m in range(1, 13):
mStart = date(year, m, 1)
mEnd = date(year, m + 1, 1) if m < 12 else date(year + 1, 1, 1)
monthCost = sum(t.get("amount", 0) for t in debits if mStart <= t["_txDate"] < mEnd)
monthCount = sum(1 for t in debits if mStart <= t["_txDate"] < mEnd)
timeSeries.append({
"date": f"{year}-{m:02d}",
"cost": round(monthCost, 4),
"count": monthCount
})
return ViewStatisticsResponse(
totalCost=round(totalCost, 4),
transactionCount=len(debits),
costByProvider=costByProvider,
costByFeature=costByFeature,
costByMandate=costByMandate,
timeSeries=timeSeries
)
except Exception as e:
logger.error(f"Error getting view statistics: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/view/users/transactions", response_model=PaginatedResponse[UserTransactionResponse])
@limiter.limit("30/minute") @limiter.limit("30/minute")
async def getUserViewTransactions( async def getUserViewTransactions(
request: Request, request: Request,
limit: int = Query(default=100, ge=1, le=1000), pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
ctx: RequestContext = Depends(getRequestContext) ctx: RequestContext = Depends(getRequestContext)
): ) -> PaginatedResponse[UserTransactionResponse]:
""" """
Get user-level transactions. Get user-level transactions with pagination support.
- SysAdmin: sees all user transactions across all mandates - SysAdmin: sees all user transactions across all mandates
- MandateAdmin: sees user transactions for mandates they manage - MandateAdmin: sees user transactions for mandates they manage
- Regular user: sees only their own transactions - Regular user: sees only their own transactions
Query Parameters:
- pagination: JSON-encoded PaginationParams object, or None for no pagination
""" """
try: try:
billingInterface = getBillingInterface(ctx.user, ctx.mandateId) billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
# Parse pagination params
paginationParams = None
if pagination:
import json
paginationDict = json.loads(pagination)
paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict)
# Determine which mandates the user has access to # Determine which mandates the user has access to
if ctx.user.isSysAdmin: if ctx.user.isSysAdmin:
# SysAdmin sees all # SysAdmin sees all
@ -812,34 +1004,78 @@ async def getUserViewTransactions(
mandateIds.append(mandateId) mandateIds.append(mandateId)
if not mandateIds: if not mandateIds:
return [] return PaginatedResponse(items=[], pagination=None)
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=limit) allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
# Non-admin users only see their own transactions logger.debug(f"Found {len(allTransactions)} transactions for mandates {mandateIds}")
if not ctx.user.isSysAdmin:
allTransactions = [t for t in allTransactions if t.get("userId") == ctx.user.id]
result = [] # Convert to response objects as dicts for filtering/sorting
transactionDicts = []
for t in allTransactions: for t in allTransactions:
result.append(UserTransactionResponse( transactionDicts.append({
id=t.get("id"), "id": t.get("id"),
accountId=t.get("accountId"), "accountId": t.get("accountId"),
transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), "transactionType": t.get("transactionType", "DEBIT"),
amount=t.get("amount", 0.0), "amount": t.get("amount", 0.0),
description=t.get("description", ""), "description": t.get("description", ""),
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, "referenceType": t.get("referenceType"),
workflowId=t.get("workflowId"), "workflowId": t.get("workflowId"),
featureCode=t.get("featureCode"), "featureCode": t.get("featureCode"),
aicoreProvider=t.get("aicoreProvider"), "aicoreProvider": t.get("aicoreProvider"),
createdAt=t.get("_createdAt"), "createdAt": t.get("_createdAt"),
mandateId=t.get("mandateId"), "mandateId": t.get("mandateId"),
mandateName=t.get("mandateName"), "mandateName": t.get("mandateName"),
userId=t.get("userId"), "userId": t.get("userId"),
userName=t.get("userName") "userName": t.get("userName"),
)) })
return result # Apply filters and sorting
filteredDicts = _applyFiltersAndSort(transactionDicts, paginationParams)
# Convert to response models
def _toResponse(d):
return UserTransactionResponse(
id=d.get("id"),
accountId=d.get("accountId"),
transactionType=TransactionTypeEnum(d.get("transactionType", "DEBIT")),
amount=d.get("amount", 0.0),
description=d.get("description", ""),
referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None,
workflowId=d.get("workflowId"),
featureCode=d.get("featureCode"),
aicoreProvider=d.get("aicoreProvider"),
createdAt=d.get("createdAt"),
mandateId=d.get("mandateId"),
mandateName=d.get("mandateName"),
userId=d.get("userId"),
userName=d.get("userName")
)
if paginationParams:
import math
totalItems = len(filteredDicts)
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
endIdx = startIdx + paginationParams.pageSize
paginatedDicts = filteredDicts[startIdx:endIdx]
return PaginatedResponse(
items=[_toResponse(d) for d in paginatedDicts],
pagination=PaginationMetadata(
currentPage=paginationParams.page,
pageSize=paginationParams.pageSize,
totalItems=totalItems,
totalPages=totalPages,
sort=paginationParams.sort,
filters=paginationParams.filters
)
)
else:
return PaginatedResponse(
items=[_toResponse(d) for d in filteredDicts],
pagination=None
)
except Exception as e: except Exception as e:
logger.error(f"Error getting user view transactions: {e}") logger.error(f"Error getting user view transactions: {e}")

View file

@ -729,9 +729,8 @@ class ChatService:
if not priceCHF or priceCHF <= 0: if not priceCHF or priceCHF <= 0:
return return
# Extract provider from model name (e.g., "anthropic.claude-3-sonnet" -> "anthropic") # Get provider from AiCallResponse (set from model.connectorType)
modelName = getattr(aiResponse, 'modelName', '') or '' aicoreProvider = getattr(aiResponse, 'provider', None) or 'unknown'
aicoreProvider = modelName.split('.')[0] if '.' in modelName else 'unknown'
# Get feature context if available # Get feature context if available
featureInstanceId = getattr(self.services, 'featureInstanceId', None) featureInstanceId = getattr(self.services, 'featureInstanceId', None)

View file

@ -229,6 +229,7 @@ class ExtractionService:
aiResponse = AiCallResponse( aiResponse = AiCallResponse(
content="", # No content for extraction stats needed content="", # No content for extraction stats needed
modelName=model.name, modelName=model.name,
provider=model.connectorType,
priceCHF=priceCHF, priceCHF=priceCHF,
processingTime=processingTime, processingTime=processingTime,
bytesSent=bytesSent, bytesSent=bytesSent,
@ -1311,6 +1312,7 @@ class ExtractionService:
return AiCallResponse( return AiCallResponse(
content=modelResponse.content, content=modelResponse.content,
modelName=model.name, modelName=model.name,
provider=model.connectorType,
priceCHF=0.0, priceCHF=0.0,
processingTime=processingTime, processingTime=processingTime,
bytesSent=0, bytesSent=0,
@ -1416,6 +1418,7 @@ class ExtractionService:
return AiCallResponse( return AiCallResponse(
content=mergedContent, content=mergedContent,
modelName=model.name, modelName=model.name,
provider=model.connectorType,
priceCHF=sum(r.priceCHF for r in chunkResults), priceCHF=sum(r.priceCHF for r in chunkResults),
processingTime=sum(r.processingTime for r in chunkResults), processingTime=sum(r.processingTime for r in chunkResults),
bytesSent=sum(r.bytesSent for r in chunkResults), bytesSent=sum(r.bytesSent for r in chunkResults),
@ -1428,49 +1431,6 @@ class ExtractionService:
response = await aiObjects._callWithModel(model, prompt, contentPart.data, options) response = await aiObjects._callWithModel(model, prompt, contentPart.data, options)
logger.info(f"✅ Content part processed successfully with model: {model.name}") logger.info(f"✅ Content part processed successfully with model: {model.name}")
return response return response
chunks = await self.chunkContentPartForAi(contentPart, model, options, prompt)
if not chunks:
raise ValueError(f"Failed to chunk content part for model {model.name}")
logger.info(f"Starting to process {len(chunks)} chunks with model {model.name}")
if progressCallback:
progressCallback(0.0, f"Starting to process {len(chunks)} chunks")
chunkResults = []
for idx, chunk in enumerate(chunks):
chunkNum = idx + 1
chunkData = chunk.get('data', '')
logger.info(f"Processing chunk {chunkNum}/{len(chunks)} with model {model.name}")
if progressCallback:
progressCallback(chunkNum / len(chunks), f"Processing chunk {chunkNum}/{len(chunks)}")
try:
chunkResponse = await aiObjects._callWithModel(model, prompt, chunkData, options)
chunkResults.append(chunkResponse)
logger.info(f"✅ Chunk {chunkNum}/{len(chunks)} processed successfully")
if progressCallback:
progressCallback(chunkNum / len(chunks), f"Chunk {chunkNum}/{len(chunks)} processed")
except Exception as e:
logger.error(f"❌ Error processing chunk {chunkNum}/{len(chunks)}: {str(e)}")
raise
# Merge chunk results using unified mergePartResults
# Pass original contentPart to preserve typeGroup for all chunks (one-to-many: 1 part -> N chunks)
mergedContent = self.mergePartResults(chunkResults, options, [contentPart])
logger.info(f"✅ Content part chunked and processed with model: {model.name} ({len(chunks)} chunks)")
return AiCallResponse(
content=mergedContent,
modelName=model.name,
priceCHF=sum(r.priceCHF for r in chunkResults),
processingTime=sum(r.processingTime for r in chunkResults),
bytesSent=sum(r.bytesSent for r in chunkResults),
bytesReceived=sum(r.bytesReceived for r in chunkResults),
errorCount=sum(r.errorCount for r in chunkResults)
)
except Exception as e: except Exception as e:
lastError = e lastError = e