fixed billing transactions mapping and added reporting
This commit is contained in:
parent
e4d41965f3
commit
34deb5f23d
6 changed files with 273 additions and 77 deletions
|
|
@ -162,6 +162,7 @@ class AiCallResponse(BaseModel):
|
||||||
|
|
||||||
content: str = Field(description="AI response content")
|
content: str = Field(description="AI response content")
|
||||||
modelName: str = Field(description="Selected model name")
|
modelName: str = Field(description="Selected model name")
|
||||||
|
provider: str = Field(default="unknown", description="AI provider / connectorType (anthropic, openai, perplexity, etc.)")
|
||||||
priceCHF: float = Field(default=0.0, description="Calculated price in USD")
|
priceCHF: float = Field(default=0.0, description="Calculated price in USD")
|
||||||
processingTime: float = Field(default=0.0, description="Duration in seconds")
|
processingTime: float = Field(default=0.0, description="Duration in seconds")
|
||||||
bytesSent: int = Field(default=0, description="Input data size in bytes")
|
bytesSent: int = Field(default=0, description="Input data size in bytes")
|
||||||
|
|
|
||||||
|
|
@ -229,6 +229,7 @@ class AiObjects:
|
||||||
return AiCallResponse(
|
return AiCallResponse(
|
||||||
content=content,
|
content=content,
|
||||||
modelName=model.name,
|
modelName=model.name,
|
||||||
|
provider=model.connectorType,
|
||||||
priceCHF=priceCHF,
|
priceCHF=priceCHF,
|
||||||
processingTime=processingTime,
|
processingTime=processingTime,
|
||||||
bytesSent=inputBytes,
|
bytesSent=inputBytes,
|
||||||
|
|
|
||||||
|
|
@ -1174,7 +1174,7 @@ class BillingObjects:
|
||||||
|
|
||||||
def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
|
def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Get all user-level transactions for specified mandates.
|
Get all transactions for specified mandates (both USER and MANDATE accounts).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mandateIds: Optional list of mandate IDs to filter. If None, returns all.
|
mandateIds: Optional list of mandate IDs to filter. If None, returns all.
|
||||||
|
|
@ -1190,9 +1190,8 @@ class BillingObjects:
|
||||||
try:
|
try:
|
||||||
appInterface = getAppInterface(self.currentUser)
|
appInterface = getAppInterface(self.currentUser)
|
||||||
|
|
||||||
# Get all user accounts
|
# Get ALL accounts (both USER and MANDATE types) to cover all billing models
|
||||||
accountFilter = {"accountType": AccountTypeEnum.USER.value}
|
allAccounts = self.db.getRecordset(BillingAccount)
|
||||||
allAccounts = self.db.getRecordset(BillingAccount, recordFilter=accountFilter)
|
|
||||||
|
|
||||||
# Filter by mandate if specified
|
# Filter by mandate if specified
|
||||||
if mandateIds:
|
if mandateIds:
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ from modules.auth import limiter, requireSysAdmin, getRequestContext, RequestCon
|
||||||
# Import billing components
|
# Import billing components
|
||||||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
|
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
|
||||||
from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
||||||
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||||
|
from modules.routes.routeDataUsers import _applyFiltersAndSort
|
||||||
from modules.datamodels.datamodelBilling import (
|
from modules.datamodels.datamodelBilling import (
|
||||||
BillingAccount,
|
BillingAccount,
|
||||||
BillingTransaction,
|
BillingTransaction,
|
||||||
|
|
@ -779,22 +781,212 @@ async def getUserViewBalances(
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/view/users/transactions", response_model=List[UserTransactionResponse])
|
class ViewStatisticsResponse(BaseModel):
|
||||||
|
"""Aggregated statistics across all user's mandates."""
|
||||||
|
totalCost: float = 0.0
|
||||||
|
transactionCount: int = 0
|
||||||
|
costByProvider: Dict[str, float] = {}
|
||||||
|
costByFeature: Dict[str, float] = {}
|
||||||
|
costByMandate: Dict[str, float] = {}
|
||||||
|
timeSeries: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/view/statistics")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
async def getUserViewStatistics(
|
||||||
|
request: Request,
|
||||||
|
period: str = Query(default="month", description="Period: 'day' or 'month'"),
|
||||||
|
year: int = Query(default=None, description="Year"),
|
||||||
|
month: Optional[int] = Query(None, description="Month (1-12, required for period='day')"),
|
||||||
|
ctx: RequestContext = Depends(getRequestContext)
|
||||||
|
) -> ViewStatisticsResponse:
|
||||||
|
"""
|
||||||
|
Get aggregated usage statistics across all user's mandates.
|
||||||
|
- period='month': returns monthly time series for the given year
|
||||||
|
- period='day': returns daily time series for the given month/year
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
if year is None:
|
||||||
|
year = datetime.now().year
|
||||||
|
|
||||||
|
if period == "day" and not month:
|
||||||
|
month = datetime.now().month
|
||||||
|
|
||||||
|
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
||||||
|
|
||||||
|
# Get all mandates the user has access to
|
||||||
|
if ctx.user.isSysAdmin:
|
||||||
|
mandateIds = None
|
||||||
|
else:
|
||||||
|
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
|
||||||
|
appInterface = getAppInterface(ctx.user)
|
||||||
|
userMandates = appInterface.getUserMandates(ctx.user.id)
|
||||||
|
mandateIds = []
|
||||||
|
for um in userMandates:
|
||||||
|
mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
|
||||||
|
if mandateId:
|
||||||
|
mandateIds.append(mandateId)
|
||||||
|
if not mandateIds:
|
||||||
|
logger.warning("No mandate IDs found for user")
|
||||||
|
return ViewStatisticsResponse()
|
||||||
|
|
||||||
|
# Get all transactions
|
||||||
|
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
|
||||||
|
logger.info(f"View statistics: {len(allTransactions)} total transactions fetched for period={period}, year={year}, month={month}")
|
||||||
|
|
||||||
|
# Calculate date range
|
||||||
|
if period == "day":
|
||||||
|
startDate = date(year, month, 1)
|
||||||
|
if month == 12:
|
||||||
|
endDate = date(year + 1, 1, 1)
|
||||||
|
else:
|
||||||
|
endDate = date(year, month + 1, 1)
|
||||||
|
else:
|
||||||
|
startDate = date(year, 1, 1)
|
||||||
|
endDate = date(year + 1, 1, 1)
|
||||||
|
|
||||||
|
# Filter by date range and only DEBIT transactions
|
||||||
|
debits = []
|
||||||
|
skippedNoDate = 0
|
||||||
|
skippedDateRange = 0
|
||||||
|
skippedNotDebit = 0
|
||||||
|
|
||||||
|
for t in allTransactions:
|
||||||
|
createdAt = t.get("_createdAt")
|
||||||
|
if not createdAt:
|
||||||
|
skippedNoDate += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse date from various formats (DB stores as DOUBLE PRECISION / Unix timestamp)
|
||||||
|
txDate = None
|
||||||
|
if isinstance(createdAt, (int, float)):
|
||||||
|
txDate = datetime.fromtimestamp(createdAt).date()
|
||||||
|
elif isinstance(createdAt, datetime):
|
||||||
|
txDate = createdAt.date()
|
||||||
|
elif isinstance(createdAt, date) and not isinstance(createdAt, datetime):
|
||||||
|
txDate = createdAt
|
||||||
|
elif isinstance(createdAt, str):
|
||||||
|
try:
|
||||||
|
# Try as float string first (Unix timestamp)
|
||||||
|
txDate = datetime.fromtimestamp(float(createdAt)).date()
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
try:
|
||||||
|
txDate = datetime.fromisoformat(createdAt.replace("Z", "+00:00")).date()
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
skippedNoDate += 1
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
skippedNoDate += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if txDate < startDate or txDate >= endDate:
|
||||||
|
skippedDateRange += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Compare transactionType - handle both string and enum
|
||||||
|
txType = t.get("transactionType")
|
||||||
|
txTypeStr = str(txType) if txType is not None else ""
|
||||||
|
if txTypeStr != "DEBIT" and txTypeStr != "TransactionTypeEnum.DEBIT":
|
||||||
|
# Also check .value for enum objects
|
||||||
|
txTypeValue = getattr(txType, 'value', txTypeStr)
|
||||||
|
if txTypeValue != "DEBIT":
|
||||||
|
skippedNotDebit += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
t["_txDate"] = txDate
|
||||||
|
debits.append(t)
|
||||||
|
|
||||||
|
logger.info(f"View statistics: {len(debits)} DEBIT transactions after filter. "
|
||||||
|
f"Skipped: noDate={skippedNoDate}, dateRange={skippedDateRange}, notDebit={skippedNotDebit}")
|
||||||
|
|
||||||
|
# Aggregate totals
|
||||||
|
totalCost = sum(t.get("amount", 0) for t in debits)
|
||||||
|
|
||||||
|
costByProvider: Dict[str, float] = {}
|
||||||
|
costByFeature: Dict[str, float] = {}
|
||||||
|
costByMandate: Dict[str, float] = {}
|
||||||
|
|
||||||
|
for t in debits:
|
||||||
|
provider = t.get("aicoreProvider") or "unknown"
|
||||||
|
costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0)
|
||||||
|
|
||||||
|
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
|
||||||
|
featureCode = t.get("featureCode") or "unknown"
|
||||||
|
featureKey = f"{mandate} / {featureCode}"
|
||||||
|
costByFeature[featureKey] = costByFeature.get(featureKey, 0) + t.get("amount", 0)
|
||||||
|
|
||||||
|
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
|
||||||
|
costByMandate[mandate] = costByMandate.get(mandate, 0) + t.get("amount", 0)
|
||||||
|
|
||||||
|
# Build time series (raw data only, no display logic)
|
||||||
|
timeSeries = []
|
||||||
|
if period == "day":
|
||||||
|
numDays = (endDate - startDate).days
|
||||||
|
for day in range(numDays):
|
||||||
|
d = startDate + timedelta(days=day)
|
||||||
|
dayCost = sum(t.get("amount", 0) for t in debits if t["_txDate"] == d)
|
||||||
|
dayCount = sum(1 for t in debits if t["_txDate"] == d)
|
||||||
|
if dayCost > 0 or dayCount > 0:
|
||||||
|
timeSeries.append({
|
||||||
|
"date": d.isoformat(),
|
||||||
|
"cost": round(dayCost, 4),
|
||||||
|
"count": dayCount
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
for m in range(1, 13):
|
||||||
|
mStart = date(year, m, 1)
|
||||||
|
mEnd = date(year, m + 1, 1) if m < 12 else date(year + 1, 1, 1)
|
||||||
|
monthCost = sum(t.get("amount", 0) for t in debits if mStart <= t["_txDate"] < mEnd)
|
||||||
|
monthCount = sum(1 for t in debits if mStart <= t["_txDate"] < mEnd)
|
||||||
|
timeSeries.append({
|
||||||
|
"date": f"{year}-{m:02d}",
|
||||||
|
"cost": round(monthCost, 4),
|
||||||
|
"count": monthCount
|
||||||
|
})
|
||||||
|
|
||||||
|
return ViewStatisticsResponse(
|
||||||
|
totalCost=round(totalCost, 4),
|
||||||
|
transactionCount=len(debits),
|
||||||
|
costByProvider=costByProvider,
|
||||||
|
costByFeature=costByFeature,
|
||||||
|
costByMandate=costByMandate,
|
||||||
|
timeSeries=timeSeries
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting view statistics: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/view/users/transactions", response_model=PaginatedResponse[UserTransactionResponse])
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
async def getUserViewTransactions(
|
async def getUserViewTransactions(
|
||||||
request: Request,
|
request: Request,
|
||||||
limit: int = Query(default=100, ge=1, le=1000),
|
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||||
ctx: RequestContext = Depends(getRequestContext)
|
ctx: RequestContext = Depends(getRequestContext)
|
||||||
):
|
) -> PaginatedResponse[UserTransactionResponse]:
|
||||||
"""
|
"""
|
||||||
Get user-level transactions.
|
Get user-level transactions with pagination support.
|
||||||
- SysAdmin: sees all user transactions across all mandates
|
- SysAdmin: sees all user transactions across all mandates
|
||||||
- MandateAdmin: sees user transactions for mandates they manage
|
- MandateAdmin: sees user transactions for mandates they manage
|
||||||
- Regular user: sees only their own transactions
|
- Regular user: sees only their own transactions
|
||||||
|
|
||||||
|
Query Parameters:
|
||||||
|
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
||||||
|
|
||||||
|
# Parse pagination params
|
||||||
|
paginationParams = None
|
||||||
|
if pagination:
|
||||||
|
import json
|
||||||
|
paginationDict = json.loads(pagination)
|
||||||
|
paginationDict = normalize_pagination_dict(paginationDict)
|
||||||
|
paginationParams = PaginationParams(**paginationDict)
|
||||||
|
|
||||||
# Determine which mandates the user has access to
|
# Determine which mandates the user has access to
|
||||||
if ctx.user.isSysAdmin:
|
if ctx.user.isSysAdmin:
|
||||||
# SysAdmin sees all
|
# SysAdmin sees all
|
||||||
|
|
@ -812,34 +1004,78 @@ async def getUserViewTransactions(
|
||||||
mandateIds.append(mandateId)
|
mandateIds.append(mandateId)
|
||||||
|
|
||||||
if not mandateIds:
|
if not mandateIds:
|
||||||
return []
|
return PaginatedResponse(items=[], pagination=None)
|
||||||
|
|
||||||
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=limit)
|
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
|
||||||
|
|
||||||
# Non-admin users only see their own transactions
|
logger.debug(f"Found {len(allTransactions)} transactions for mandates {mandateIds}")
|
||||||
if not ctx.user.isSysAdmin:
|
|
||||||
allTransactions = [t for t in allTransactions if t.get("userId") == ctx.user.id]
|
|
||||||
|
|
||||||
result = []
|
# Convert to response objects as dicts for filtering/sorting
|
||||||
|
transactionDicts = []
|
||||||
for t in allTransactions:
|
for t in allTransactions:
|
||||||
result.append(UserTransactionResponse(
|
transactionDicts.append({
|
||||||
id=t.get("id"),
|
"id": t.get("id"),
|
||||||
accountId=t.get("accountId"),
|
"accountId": t.get("accountId"),
|
||||||
transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
|
"transactionType": t.get("transactionType", "DEBIT"),
|
||||||
amount=t.get("amount", 0.0),
|
"amount": t.get("amount", 0.0),
|
||||||
description=t.get("description", ""),
|
"description": t.get("description", ""),
|
||||||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
"referenceType": t.get("referenceType"),
|
||||||
workflowId=t.get("workflowId"),
|
"workflowId": t.get("workflowId"),
|
||||||
featureCode=t.get("featureCode"),
|
"featureCode": t.get("featureCode"),
|
||||||
aicoreProvider=t.get("aicoreProvider"),
|
"aicoreProvider": t.get("aicoreProvider"),
|
||||||
createdAt=t.get("_createdAt"),
|
"createdAt": t.get("_createdAt"),
|
||||||
mandateId=t.get("mandateId"),
|
"mandateId": t.get("mandateId"),
|
||||||
mandateName=t.get("mandateName"),
|
"mandateName": t.get("mandateName"),
|
||||||
userId=t.get("userId"),
|
"userId": t.get("userId"),
|
||||||
userName=t.get("userName")
|
"userName": t.get("userName"),
|
||||||
))
|
})
|
||||||
|
|
||||||
return result
|
# Apply filters and sorting
|
||||||
|
filteredDicts = _applyFiltersAndSort(transactionDicts, paginationParams)
|
||||||
|
|
||||||
|
# Convert to response models
|
||||||
|
def _toResponse(d):
|
||||||
|
return UserTransactionResponse(
|
||||||
|
id=d.get("id"),
|
||||||
|
accountId=d.get("accountId"),
|
||||||
|
transactionType=TransactionTypeEnum(d.get("transactionType", "DEBIT")),
|
||||||
|
amount=d.get("amount", 0.0),
|
||||||
|
description=d.get("description", ""),
|
||||||
|
referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None,
|
||||||
|
workflowId=d.get("workflowId"),
|
||||||
|
featureCode=d.get("featureCode"),
|
||||||
|
aicoreProvider=d.get("aicoreProvider"),
|
||||||
|
createdAt=d.get("createdAt"),
|
||||||
|
mandateId=d.get("mandateId"),
|
||||||
|
mandateName=d.get("mandateName"),
|
||||||
|
userId=d.get("userId"),
|
||||||
|
userName=d.get("userName")
|
||||||
|
)
|
||||||
|
|
||||||
|
if paginationParams:
|
||||||
|
import math
|
||||||
|
totalItems = len(filteredDicts)
|
||||||
|
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||||
|
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||||
|
endIdx = startIdx + paginationParams.pageSize
|
||||||
|
paginatedDicts = filteredDicts[startIdx:endIdx]
|
||||||
|
|
||||||
|
return PaginatedResponse(
|
||||||
|
items=[_toResponse(d) for d in paginatedDicts],
|
||||||
|
pagination=PaginationMetadata(
|
||||||
|
currentPage=paginationParams.page,
|
||||||
|
pageSize=paginationParams.pageSize,
|
||||||
|
totalItems=totalItems,
|
||||||
|
totalPages=totalPages,
|
||||||
|
sort=paginationParams.sort,
|
||||||
|
filters=paginationParams.filters
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return PaginatedResponse(
|
||||||
|
items=[_toResponse(d) for d in filteredDicts],
|
||||||
|
pagination=None
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user view transactions: {e}")
|
logger.error(f"Error getting user view transactions: {e}")
|
||||||
|
|
|
||||||
|
|
@ -729,9 +729,8 @@ class ChatService:
|
||||||
if not priceCHF or priceCHF <= 0:
|
if not priceCHF or priceCHF <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Extract provider from model name (e.g., "anthropic.claude-3-sonnet" -> "anthropic")
|
# Get provider from AiCallResponse (set from model.connectorType)
|
||||||
modelName = getattr(aiResponse, 'modelName', '') or ''
|
aicoreProvider = getattr(aiResponse, 'provider', None) or 'unknown'
|
||||||
aicoreProvider = modelName.split('.')[0] if '.' in modelName else 'unknown'
|
|
||||||
|
|
||||||
# Get feature context if available
|
# Get feature context if available
|
||||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||||
|
|
|
||||||
|
|
@ -229,6 +229,7 @@ class ExtractionService:
|
||||||
aiResponse = AiCallResponse(
|
aiResponse = AiCallResponse(
|
||||||
content="", # No content for extraction stats needed
|
content="", # No content for extraction stats needed
|
||||||
modelName=model.name,
|
modelName=model.name,
|
||||||
|
provider=model.connectorType,
|
||||||
priceCHF=priceCHF,
|
priceCHF=priceCHF,
|
||||||
processingTime=processingTime,
|
processingTime=processingTime,
|
||||||
bytesSent=bytesSent,
|
bytesSent=bytesSent,
|
||||||
|
|
@ -1311,6 +1312,7 @@ class ExtractionService:
|
||||||
return AiCallResponse(
|
return AiCallResponse(
|
||||||
content=modelResponse.content,
|
content=modelResponse.content,
|
||||||
modelName=model.name,
|
modelName=model.name,
|
||||||
|
provider=model.connectorType,
|
||||||
priceCHF=0.0,
|
priceCHF=0.0,
|
||||||
processingTime=processingTime,
|
processingTime=processingTime,
|
||||||
bytesSent=0,
|
bytesSent=0,
|
||||||
|
|
@ -1416,6 +1418,7 @@ class ExtractionService:
|
||||||
return AiCallResponse(
|
return AiCallResponse(
|
||||||
content=mergedContent,
|
content=mergedContent,
|
||||||
modelName=model.name,
|
modelName=model.name,
|
||||||
|
provider=model.connectorType,
|
||||||
priceCHF=sum(r.priceCHF for r in chunkResults),
|
priceCHF=sum(r.priceCHF for r in chunkResults),
|
||||||
processingTime=sum(r.processingTime for r in chunkResults),
|
processingTime=sum(r.processingTime for r in chunkResults),
|
||||||
bytesSent=sum(r.bytesSent for r in chunkResults),
|
bytesSent=sum(r.bytesSent for r in chunkResults),
|
||||||
|
|
@ -1428,49 +1431,6 @@ class ExtractionService:
|
||||||
response = await aiObjects._callWithModel(model, prompt, contentPart.data, options)
|
response = await aiObjects._callWithModel(model, prompt, contentPart.data, options)
|
||||||
logger.info(f"✅ Content part processed successfully with model: {model.name}")
|
logger.info(f"✅ Content part processed successfully with model: {model.name}")
|
||||||
return response
|
return response
|
||||||
chunks = await self.chunkContentPartForAi(contentPart, model, options, prompt)
|
|
||||||
if not chunks:
|
|
||||||
raise ValueError(f"Failed to chunk content part for model {model.name}")
|
|
||||||
|
|
||||||
logger.info(f"Starting to process {len(chunks)} chunks with model {model.name}")
|
|
||||||
|
|
||||||
if progressCallback:
|
|
||||||
progressCallback(0.0, f"Starting to process {len(chunks)} chunks")
|
|
||||||
|
|
||||||
chunkResults = []
|
|
||||||
for idx, chunk in enumerate(chunks):
|
|
||||||
chunkNum = idx + 1
|
|
||||||
chunkData = chunk.get('data', '')
|
|
||||||
logger.info(f"Processing chunk {chunkNum}/{len(chunks)} with model {model.name}")
|
|
||||||
|
|
||||||
if progressCallback:
|
|
||||||
progressCallback(chunkNum / len(chunks), f"Processing chunk {chunkNum}/{len(chunks)}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunkResponse = await aiObjects._callWithModel(model, prompt, chunkData, options)
|
|
||||||
chunkResults.append(chunkResponse)
|
|
||||||
logger.info(f"✅ Chunk {chunkNum}/{len(chunks)} processed successfully")
|
|
||||||
|
|
||||||
if progressCallback:
|
|
||||||
progressCallback(chunkNum / len(chunks), f"Chunk {chunkNum}/{len(chunks)} processed")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Error processing chunk {chunkNum}/{len(chunks)}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Merge chunk results using unified mergePartResults
|
|
||||||
# Pass original contentPart to preserve typeGroup for all chunks (one-to-many: 1 part -> N chunks)
|
|
||||||
mergedContent = self.mergePartResults(chunkResults, options, [contentPart])
|
|
||||||
|
|
||||||
logger.info(f"✅ Content part chunked and processed with model: {model.name} ({len(chunks)} chunks)")
|
|
||||||
return AiCallResponse(
|
|
||||||
content=mergedContent,
|
|
||||||
modelName=model.name,
|
|
||||||
priceCHF=sum(r.priceCHF for r in chunkResults),
|
|
||||||
processingTime=sum(r.processingTime for r in chunkResults),
|
|
||||||
bytesSent=sum(r.bytesSent for r in chunkResults),
|
|
||||||
bytesReceived=sum(r.bytesReceived for r in chunkResults),
|
|
||||||
errorCount=sum(r.errorCount for r in chunkResults)
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
lastError = e
|
lastError = e
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue