331 lines
13 KiB
Python
331 lines
13 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
User-scoped voice settings and TTS/STT catalog endpoints.
|
|
|
|
Uses modules.interfaces.interfaceVoiceObjects (voice core) and persists preferences
|
|
via UserVoicePreferences — same domain as routeVoiceGoogle (Google connector ops).
|
|
"""
|
|
|
|
import base64
|
|
import logging
|
|
from typing import Any, Dict
|
|
|
|
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, status
|
|
|
|
from modules.auth import getCurrentUser, limiter
|
|
from modules.datamodels.datamodelUam import User, UserVoicePreferences, _normalizeTtsVoiceMap
|
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
|
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface
|
|
from modules.shared.i18nRegistry import apiRouteContext
|
|
# Message helper from the shared i18n registry, scoped to this route module; some
# HTTPException details below are passed through it.
routeApiMsg = apiRouteContext("routeVoiceUser")

logger = logging.getLogger(__name__)

# All voice user-settings endpoints live under /api/voice; the error descriptions below are
# advertised in OpenAPI for every route registered on this router.
router = APIRouter(
    prefix="/api/voice",
    tags=["Voice User"],
    responses={
        code: {"description": text}
        for code, text in (
            (404, "Not found"),
            (400, "Bad request"),
            (401, "Unauthorized"),
            (403, "Forbidden"),
            (500, "Internal server error"),
        )
    },
)
|
|
|
|
|
|
@router.get("/preferences")
@limiter.limit("60/minute")
def getVoicePreferences(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Return the caller's stored voice/language preferences.

    The lookup is filtered by the caller's user id and by the ``X-Mandate-Id`` header value
    (``None`` when the header is absent or empty). If no record matches, a default
    ``UserVoicePreferences`` for that user/mandate is returned; it is not persisted.
    """
    db = getRootInterface().db
    scopeMandateId = request.headers.get("X-Mandate-Id") or None
    ownerId = str(currentUser.id)

    rows = db.getRecordset(
        UserVoicePreferences,
        recordFilter={"userId": ownerId, "mandateId": scopeMandateId},
    )
    if not rows:
        return UserVoicePreferences(userId=ownerId, mandateId=scopeMandateId).model_dump()

    firstRow = rows[0]
    # The DB layer may return plain dicts or model instances; always answer with a dict.
    if isinstance(firstRow, dict):
        return firstRow
    return firstRow.model_dump()
|
|
|
|
|
|
@router.put("/preferences")
@limiter.limit("30/minute")
def updateVoicePreferences(
    request: Request,
    preferences: Dict[str, Any] = Body(...),
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Create or update the caller's voice/language preferences (upsert).

    The record is keyed by the caller's user id and the ``X-Mandate-Id`` header value
    (``None`` when absent/empty). Only whitelisted fields from the request body are applied;
    other keys are ignored. ``ttsVoiceMap`` is normalized via ``_normalizeTtsVoiceMap``.

    Both the create and the modify path validate the submitted values through
    ``UserVoicePreferences``; previously only the create path did, so updates were written
    to the DB unvalidated.

    Returns:
        The stored record as a dict (or ``{"message": "Updated", ...}`` if the updated
        record cannot be re-read).

    Raises:
        HTTPException 400: the submitted values fail model validation.
    """
    rootInterface = getRootInterface()
    mandateId = request.headers.get("X-Mandate-Id") or None
    userId = str(currentUser.id)

    existing = rootInterface.db.getRecordset(
        UserVoicePreferences,
        recordFilter={"userId": userId, "mandateId": mandateId},
    )

    allowedFields = {
        "sttLanguage",
        "ttsLanguage",
        "ttsVoice",
        "ttsVoiceMap",
        "translationSourceLanguage",
        "translationTargetLanguage",
    }
    updateData = {k: v for k, v in preferences.items() if k in allowedFields}
    if "ttsVoiceMap" in updateData:
        updateData["ttsVoiceMap"] = _normalizeTtsVoiceMap(updateData["ttsVoiceMap"])

    # Validate on both paths. pydantic's ValidationError subclasses ValueError, so this
    # turns bad client input into a 400 instead of an unhandled 500.
    try:
        validated = UserVoicePreferences(userId=userId, mandateId=mandateId, **updateData)
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    if existing:
        existingRecord = existing[0]
        existingId = existingRecord.get("id") if isinstance(existingRecord, dict) else existingRecord.id
        # Persist the validated (possibly coerced) values for exactly the submitted fields;
        # fall back to the raw value for any key the model does not dump.
        dumped = validated.model_dump()
        modifyData = {k: dumped.get(k, v) for k, v in updateData.items()}
        rootInterface.db.recordModify(UserVoicePreferences, existingId, modifyData)
        updated = rootInterface.db.getRecordset(UserVoicePreferences, recordFilter={"id": existingId})
        if updated:
            row = updated[0]
            # Normalize like the GET endpoint and the create path: always return a dict.
            return row if isinstance(row, dict) else row.model_dump()
        return {"message": "Updated", **modifyData}

    created = rootInterface.db.recordCreate(UserVoicePreferences, validated.model_dump())
    return created if isinstance(created, dict) else created.model_dump()
|
|
|
|
|
|
@router.get("/languages")
@limiter.limit("120/minute")
async def getVoiceLanguages(
    request: Request,
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """List the TTS languages available to the caller (no feature-instance context needed).

    A dict result from the voice interface is unwrapped via its ``"languages"`` key (empty
    list if absent); any other result is passed through unchanged under ``"languages"``.
    """
    result = await getVoiceInterface(currentUser).getAvailableLanguages()
    if isinstance(result, dict):
        result = result.get("languages", [])
    return {"languages": result}
|
|
|
|
|
|
@router.get("/voices")
@limiter.limit("120/minute")
async def getVoiceVoices(
    request: Request,
    language: str = Query("de-DE"),
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """List the TTS voices available for ``language`` (query parameter, default ``de-DE``).

    A dict result from the voice interface is unwrapped via its ``"voices"`` key (empty
    list if absent); any other result is passed through unchanged under ``"voices"``.
    """
    result = await getVoiceInterface(currentUser).getAvailableVoices(language)
    if isinstance(result, dict):
        result = result.get("voices", [])
    return {"voices": result}
|
|
|
|
|
|
# Same minimum as modules.serviceCenter.services.serviceAi.mainServiceAi._checkBillingBeforeAiCall
|
|
_MIN_AI_BILLING_ESTIMATE_CHF = 0.01
|
|
|
|
|
|
def _userMandateIds(rootInterface, currentUser: User):
    """Return the distinct mandate IDs (as strings) the user belongs to, in first-seen order.

    Memberships from ``rootInterface.getUserMandates`` may expose ``mandateId`` as an
    attribute or be plain dicts; entries without a truthy mandate ID are skipped.
    """
    seen = {}
    for membership in rootInterface.getUserMandates(str(currentUser.id)):
        mandateId = getattr(membership, "mandateId", None)
        if not mandateId and isinstance(membership, dict):
            mandateId = membership.get("mandateId")
        if mandateId:
            # dict preserves insertion order, so setdefault dedupes while keeping order.
            seen.setdefault(str(mandateId), None)
    return list(seen)
|
|
|
|
|
|
def _mandatePassesAiPoolBilling(currentUser: User, mandateId: str, userId: str) -> bool:
    """Return True when the mandate passes the same billing gate used for AI calls.

    Asks the billing interface's ``checkBalance`` whether ``_MIN_AI_BILLING_ESTIMATE_CHF``
    may be charged for ``userId`` against ``mandateId`` and returns its ``allowed`` flag.
    """
    from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface

    billing = getBillingInterface(currentUser, mandateId)
    verdict = billing.checkBalance(mandateId, userId, _MIN_AI_BILLING_ESTIMATE_CHF)
    return bool(verdict.allowed)
|
|
|
|
|
|
def _mandatePoolBalanceChf(currentUser: User, mandateId: str) -> float:
    """Return the mandate's pool account balance in CHF, or 0.0 when there is no account.

    Objects with a ``.get`` method (e.g. dicts) are read by key, as before. Other objects
    are read via a ``balance`` attribute, matching the dict-or-model handling used for DB
    records elsewhere in this module. A missing or None balance counts as 0.0.
    """
    from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface

    bi = getBillingInterface(currentUser, mandateId)
    acc = bi.getMandateAccount(mandateId)
    if not acc:
        return 0.0
    getter = getattr(acc, "get", None)
    # Previously only mapping-style access was supported; a model object raised AttributeError.
    balance = getter("balance", 0.0) if callable(getter) else getattr(acc, "balance", 0.0)
    return float(balance or 0.0)
|
|
|
|
|
|
def _resolveMandateIdForVoiceTestAi(request: Request, currentUser: User) -> str:
    """Pick the mandate whose shared pool pays for the AI-generated voice test sample.

    AI sample billing uses the mandate pool (PREPAY), not a per-user wallet.

    Resolution order:
    1. The ``X-Mandate-Id`` header, when the user is a member of that mandate and its pool
       passes the AI billing check.
    2. Otherwise, among the user's remaining member mandates that pass the check, the one
       with the highest pool balance (first in membership order on ties).

    Raises:
        HTTPException 400: the user has no mandate memberships.
        HTTPException 403: the header names a mandate the user does not belong to.
        HTTPException 402: no member mandate passes the billing check.
    """
    rootInterface = getRootInterface()
    userId = str(currentUser.id)
    memberIds = _userMandateIds(rootInterface, currentUser)
    if not memberIds:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "Voice test needs at least one mandate membership for AI billing. "
                "Join a mandate or open the app from a mandate context."
            ),
        )

    headerRaw = (request.headers.get("X-Mandate-Id") or request.headers.get("x-mandate-id") or "").strip()
    if headerRaw:
        if headerRaw not in memberIds:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=routeApiMsg("X-Mandate-Id is not a mandate you belong to."),
            )
        if _mandatePassesAiPoolBilling(currentUser, headerRaw, userId):
            logger.info(
                "Voice test AI billing: using header mandate %s (pool ok for estimate %.4f CHF)",
                headerRaw,
                _MIN_AI_BILLING_ESTIMATE_CHF,
            )
            return headerRaw
        logger.warning(
            "Voice test AI billing: header mandate %s has insufficient mandate pool or subscription; "
            "trying other memberships",
            headerRaw,
        )

    # (mandateId, poolBalance) for every membership that passes the billing gate. The header
    # mandate (if any) was just rejected above, so it is not billing-checked a second time.
    candidates = []
    for mid in memberIds:
        if mid == headerRaw or not _mandatePassesAiPoolBilling(currentUser, mid, userId):
            continue
        candidates.append((mid, _mandatePoolBalanceChf(currentUser, mid)))

    if candidates:
        # No sentinel floor: a mandate that passes billing stays eligible even with a
        # negative pool balance. max() keeps the first maximum, i.e. membership order on ties.
        bestMid, bestBal = max(candidates, key=lambda candidate: candidate[1])
        logger.info(
            "Voice test AI billing: selected mandate %s (mandate pool %.2f CHF, estimate %.4f CHF)",
            bestMid,
            bestBal,
            _MIN_AI_BILLING_ESTIMATE_CHF,
        )
        return bestMid

    raise HTTPException(
        status_code=status.HTTP_402_PAYMENT_REQUIRED,
        detail=(
            "No mandate you belong to has sufficient shared pool balance for AI (or subscription inactive). "
            "Top up the mandate pool or use a mandate with budget."
        ),
    )
|
|
|
|
|
|
def _sanitizeAiTtsSample(raw: str) -> str:
|
|
s = (raw or "").strip()
|
|
if s.startswith("```"):
|
|
nl = s.find("\n")
|
|
if nl != -1:
|
|
s = s[nl + 1 :]
|
|
if s.rstrip().endswith("```"):
|
|
s = s.rstrip()[:-3].strip()
|
|
if len(s) >= 2 and ((s[0] == s[-1] == '"') or (s[0] == s[-1] == "'")):
|
|
s = s[1:-1].strip()
|
|
return s
|
|
|
|
|
|
async def _generateTtsSampleTextForLocale(
    request: Request,
    currentUser: User,
    localeTag: str,
) -> str:
    """Generate a short AI-written TTS demo sentence in the language of ``localeTag``.

    The AI call is billed to the mandate chosen by ``_resolveMandateIdForVoiceTestAi``.
    ``localeTag`` can come straight from a client request body. It is therefore checked
    against a BCP-47-like shape before being embedded in the prompt, so arbitrary client
    text cannot steer a pool-billed AI call.

    Returns:
        Sanitized plain text, truncated to at most 500 characters.

    Raises:
        HTTPException 400: malformed locale tag, or a billing context error.
        HTTPException 402: insufficient pool balance.
        HTTPException 403: inactive subscription or provider not allowed.
        HTTPException 502: the AI response reported errors or produced no text.
        Any HTTPException raised while resolving the billing mandate.
    """
    import re

    from modules.serviceCenter import getService
    from modules.serviceCenter.context import ServiceCenterContext
    from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum
    from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
        BillingContextError,
        InsufficientBalanceException,
        ProviderNotAllowedException,
    )
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionInactiveException

    # BCP-47-like: a primary subtag plus up to 7 subtags ("de-DE", "cmn-Hant-TW"); "_" is
    # tolerated as a separator. Checked before any billing lookup or AI call.
    if not isinstance(localeTag, str) or not re.fullmatch(
        r"[A-Za-z]{2,8}(?:[-_][A-Za-z0-9]{1,8}){0,7}", localeTag
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=routeApiMsg("Invalid language code for voice test."),
        )

    mandateId = _resolveMandateIdForVoiceTestAi(request, currentUser)
    ctx = ServiceCenterContext(user=currentUser, mandate_id=mandateId, feature_instance_id=None)
    aiService = getService("ai", ctx)

    systemPrompt = (
        "You write short text-to-speech demo lines for end users.\n"
        "Task: Output exactly one or two natural sentences a user would enjoy hearing when testing a voice.\n"
        "The entire output MUST be written ONLY in the natural spoken language that matches the given "
        "BCP-47 locale tag. Do not use any other language.\n"
        "Do not mention locales, tags, tests, artificial intelligence, or these instructions.\n"
        "No quotation marks around the text. No markdown. Plain text only."
    )
    userPrompt = f"BCP-47 locale tag: `{localeTag}`.\nWrite the sample now."

    aiRequest = AiCallRequest(
        prompt=userPrompt,
        context=systemPrompt,
        requireNeutralization=False,
        options=AiCallOptions(
            operationType=OperationTypeEnum.DATA_GENERATE,
            priority=PriorityEnum.SPEED,
            processingMode=ProcessingModeEnum.BASIC,
            compressPrompt=False,
            compressContext=False,
            temperature=0.75,
            maxParts=1,
        ),
    )
    try:
        response = await aiService.callAi(aiRequest)
    except SubscriptionInactiveException as e:
        # Read .message defensively (same as the provider handler below) so a missing
        # attribute cannot turn a 403 into a 500.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=getattr(e, "message", None) or str(e),
        ) from e
    except InsufficientBalanceException as e:
        raise HTTPException(status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=str(e)) from e
    except ProviderNotAllowedException as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=getattr(e, "message", None) or str(e),
        ) from e
    except BillingContextError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    content = _sanitizeAiTtsSample(getattr(response, "content", None) or "")
    if getattr(response, "errorCount", 0) or not content:
        logger.warning("Voice test AI sample empty or errorCount=%s", getattr(response, "errorCount", None))
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=routeApiMsg("Could not generate voice test sample text."),
        )
    if len(content) > 500:
        content = content[:500].rstrip()
    return content
|
|
|
|
|
|
@router.post("/test")
@limiter.limit("30/minute")
async def testVoice(
    request: Request,
    body: Dict[str, Any] = Body(...),
    currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
    """Synthesize a short sample with a specific voice.

    Body fields:
        text: optional sample text. When missing, blank, or not a string, a sample is
            AI-generated in the voice locale (billed to a mandate pool).
        language: locale for the sample and TTS. Defaults to ``"de-DE"`` when missing,
            null, blank, or not a string.
        voiceId: passed to TTS as ``voiceName``.

    Returns:
        ``{"success": True, "audio": <base64>, "format": "mp3", "text": <sample>}`` when TTS
        yields audio, otherwise ``{"success": False, "error": "TTS returned no audio"}``.
    """
    textRaw = body.get("text")
    languageRaw = body.get("language")
    # body.get("language", "de-DE") only covered a missing key; an explicit null/""/non-str
    # used to flow straight into the AI prompt and the TTS call.
    if isinstance(languageRaw, str) and languageRaw.strip():
        language = languageRaw.strip()
    else:
        language = "de-DE"
    voiceId = body.get("voiceId")

    text = textRaw.strip() if isinstance(textRaw, str) else ""
    if not text:
        text = await _generateTtsSampleTextForLocale(request, currentUser, language)

    voiceInterface = getVoiceInterface(currentUser)
    result = await voiceInterface.textToSpeech(text=text, languageCode=language, voiceName=voiceId)
    if result and isinstance(result, dict):
        audioContent = result.get("audioContent")
        if audioContent:
            audioB64 = base64.b64encode(
                audioContent if isinstance(audioContent, bytes) else audioContent.encode()
            ).decode()
            return {"success": True, "audio": audioB64, "format": "mp3", "text": text}
    return {"success": False, "error": "TTS returned no audio"}
|