platform-core/modules/routes/routeAdminSttBenchmark.py
2026-05-16 22:55:43 +02:00

217 lines
7.3 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""STT Benchmark route — compare Speech-to-Text v1 (latest_long) vs v2 (Chirp 2).
Sysadmin-only page for evaluating STT model quality and latency.
"""
import json
import time
import logging
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
from modules.auth import limiter, getCurrentUser
from modules.datamodels.datamodelUam import User
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)

# Every route in this module is mounted under /api/admin/stt-benchmark; the
# 401/403 descriptions below are shared by all routes on this router.
router = APIRouter(
    prefix="/api/admin/stt-benchmark",
    tags=["Admin STT Benchmark"],
    responses={
        statusCode: {"description": text}
        for statusCode, text in ((401, "Unauthorized"), (403, "Forbidden"))
    },
)
def _requireSysAdmin(currentUser: User = Depends(getCurrentUser)) -> User:
    """FastAPI dependency that only lets administrators through.

    A user passes when either the ``isSysAdmin`` or the ``isPlatformAdmin``
    attribute is truthy; a missing attribute counts as False.

    Returns:
        The authenticated user, unchanged.

    Raises:
        HTTPException: 403 when the user has neither admin flag.
    """
    hasAdminRole = getattr(currentUser, "isSysAdmin", False) or getattr(currentUser, "isPlatformAdmin", False)
    if hasAdminRole:
        return currentUser
    raise HTTPException(status_code=403, detail="SysAdmin required")
def _getCredentials():
    """Build Google service-account credentials from the configured secret.

    Despite its name, the config value ``Connector_GoogleSpeech_API_KEY_SECRET``
    must hold a service-account key as a JSON object (it is parsed with
    ``json.loads`` and handed to ``from_service_account_info``).

    Returns:
        ``google.oauth2.service_account.Credentials`` built from the secret.

    Raises:
        HTTPException: 500 when the secret is missing, still a ``YOUR_...``
            placeholder, not valid JSON, not a JSON object, or rejected by
            google-auth as a service-account key.
    """
    apiKey = APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET")
    if not apiKey or apiKey.startswith("YOUR_"):
        raise HTTPException(status_code=500, detail="Google Speech API key not configured")

    try:
        credInfo = json.loads(apiKey)
    except ValueError:
        # Report a clear configuration error instead of a bare decoder message;
        # `from None` keeps the secret-bearing parse context out of the chain.
        raise HTTPException(
            status_code=500,
            detail="Google Speech credentials are not valid JSON",
        ) from None
    if not isinstance(credInfo, dict):
        raise HTTPException(
            status_code=500,
            detail="Google Speech credentials must be a service-account JSON object",
        )

    from google.oauth2 import service_account

    try:
        return service_account.Credentials.from_service_account_info(credInfo)
    except ValueError as e:
        # google-auth raises ValueError for missing fields / unparsable key data.
        raise HTTPException(
            status_code=500,
            detail=f"Google Speech service-account credentials are invalid: {e}",
        ) from e
def _runV1(audioBytes: bytes, language: str, model: str) -> Dict[str, Any]:
    """Run one synchronous Speech-to-Text v1 recognition and summarise it.

    Args:
        audioBytes: Raw uploaded audio. The encoding is left as
            ENCODING_UNSPECIFIED, so the API has to infer it from the file
            header (the v1 docs only allow this for header-bearing formats
            such as WAV/FLAC).
        language: Language code passed as ``language_code`` (e.g. "de-DE").
        model: v1 model name (e.g. "latest_long").

    Returns:
        dict with ``api``, ``model``, ``latencyMs`` (wall time of the
        ``recognize`` call only), ``results`` (one entry per alternative:
        transcript, confidence rounded to 4 places, word count) and
        ``resultCount`` (number of result segments).

    Raises:
        HTTPException: from ``_getCredentials`` when credentials are unusable.
        Any Google API error raised by ``recognize`` (e.g. audio the
        synchronous call does not accept).
    """
    from google.cloud import speech

    credentials = _getCredentials()
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
        language_code=language,
        model=model,
        enable_automatic_punctuation=True,
        enable_word_time_offsets=True,
        enable_word_confidence=True,
        max_alternatives=3,
        use_enhanced=True,
    )
    audio = speech.RecognitionAudio(content=audioBytes)

    # A fresh client is built per call; the context manager closes its
    # transport afterwards instead of leaking one gRPC channel per request.
    with speech.SpeechClient(credentials=credentials) as client:
        t0 = time.perf_counter()
        response = client.recognize(config=config, audio=audio)
        elapsed = time.perf_counter() - t0

    results = [
        {
            "transcript": alt.transcript,
            "confidence": round(alt.confidence, 4),
            "words": len(alt.words),
        }
        for r in response.results
        for alt in r.alternatives
    ]
    return {
        "api": "v1",
        "model": model,
        "latencyMs": round(elapsed * 1000, 1),
        "results": results,
        "resultCount": len(response.results),
    }
def _runV2(audioBytes: bytes, language: str, model: str, location: str) -> Dict[str, Any]:
"""Run Speech-to-Text v2 recognition (Chirp 2)."""
from google.cloud.speech_v2 import SpeechClient
from google.cloud.speech_v2.types import cloud_speech
credentials = _getCredentials()
credInfo = json.loads(APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET"))
projectId = credInfo.get("project_id", "")
client = SpeechClient(
credentials=credentials,
client_options={"api_endpoint": f"{location}-speech.googleapis.com"},
)
config = cloud_speech.RecognitionConfig(
auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
language_codes=[language],
model=model,
features=cloud_speech.RecognitionFeatures(
enable_automatic_punctuation=True,
enable_word_time_offsets=True,
enable_word_confidence=True,
),
)
recognizer = f"projects/{projectId}/locations/{location}/recognizers/_"
request = cloud_speech.RecognizeRequest(
recognizer=recognizer,
config=config,
content=audioBytes,
)
t0 = time.perf_counter()
response = client.recognize(request=request)
elapsed = time.perf_counter() - t0
results = []
for r in response.results:
for alt in r.alternatives:
results.append({
"transcript": alt.transcript,
"confidence": round(alt.confidence, 4),
"words": len(alt.words) if alt.words else 0,
})
return {
"api": "v2",
"model": model,
"location": location,
"latencyMs": round(elapsed * 1000, 1),
"results": results,
"resultCount": len(getattr(response, "results", [])),
}
@router.post("/run")
@limiter.limit("10/minute")
async def runBenchmark(
    request: Request,
    file: UploadFile = File(...),
    language: str = Form(default="de-DE"),
    v1Model: str = Form(default="latest_long"),
    v2Model: str = Form(default="chirp_2"),
    v2Location: str = Form(default="europe-west4"),
    currentUser: User = Depends(_requireSysAdmin),
) -> Dict[str, Any]:
    """Upload audio and compare Speech-to-Text v1 vs v2 results on it.

    Both backends receive the same bytes and run one after the other. Each is
    attempted independently: a failure is reported as ``{"error": "..."}``
    under its key instead of failing the whole request.

    Returns:
        dict with ``filename``, ``fileSizeBytes``, ``language``, ``v1`` and ``v2``.

    Raises:
        HTTPException: 400 when the upload is larger than 10 MB or smaller
            than 100 bytes.
    """
    from fastapi.concurrency import run_in_threadpool

    maxBytes = 10 * 1024 * 1024
    # Read at most one byte past the limit so an oversized upload is rejected
    # without pulling all of it into memory.
    audioBytes = await file.read(maxBytes + 1)
    if len(audioBytes) > maxBytes:
        raise HTTPException(status_code=400, detail="Audio file too large (max 10 MB)")
    if len(audioBytes) < 100:
        raise HTTPException(status_code=400, detail="Audio file too small")
    logger.info("STT benchmark: %s, %d bytes, language=%s, v1=%s, v2=%s@%s",
                file.filename, len(audioBytes), language, v1Model, v2Model, v2Location)

    async def _attempt(label: str, fn, *args):
        """Run one blocking STT call off the event loop; return (result, errorText)."""
        try:
            # _runV1/_runV2 make synchronous network calls; running them
            # directly in this coroutine would block the whole event loop.
            return await run_in_threadpool(fn, *args), None
        except Exception as e:  # best-effort: one backend failing must not hide the other
            logger.warning("STT %s benchmark failed: %s", label, e)
            return None, str(e)

    v1Result, v1Error = await _attempt("v1", _runV1, audioBytes, language, v1Model)
    v2Result, v2Error = await _attempt("v2", _runV2, audioBytes, language, v2Model, v2Location)
    return {
        "filename": file.filename,
        "fileSizeBytes": len(audioBytes),
        "language": language,
        "v1": v1Result or {"error": v1Error},
        "v2": v2Result or {"error": v2Error},
    }
@router.get("/models")
@limiter.limit("30/minute")
async def getAvailableModels(
    request: Request,
    currentUser: User = Depends(_requireSysAdmin),
) -> Dict[str, Any]:
    """List the selectable options for the STT benchmark UI.

    Every option is a ``{"value": ..., "label": ...}`` dict: ``value`` is the
    identifier, ``label`` the display text. Options are grouped under
    ``v1Models``, ``v2Models``, ``locations`` and ``languages``.
    """
    def _options(*pairs):
        # Expand (value, label) tuples into the option dicts the UI expects.
        return [{"value": value, "label": label} for value, label in pairs]

    v1Options = _options(
        ("latest_long", "latest_long (default)"),
        ("latest_short", "latest_short"),
        ("phone_call", "phone_call"),
        ("video", "video"),
        ("command_and_search", "command_and_search"),
    )
    v2Options = _options(
        ("chirp_2", "Chirp 2 (recommended)"),
        ("chirp", "Chirp (original)"),
        ("long", "long"),
        ("short", "short"),
    )
    locationOptions = _options(
        ("europe-west4", "Europe West (NL)"),
        ("us-central1", "US Central"),
        ("asia-southeast1", "Asia Southeast"),
    )
    languageOptions = _options(
        ("de-DE", "Deutsch (DE)"),
        ("de-CH", "Deutsch (CH)"),
        ("en-US", "English (US)"),
        ("en-GB", "English (GB)"),
        ("fr-FR", "Francais (FR)"),
        ("it-IT", "Italiano (IT)"),
    )
    return {
        "v1Models": v1Options,
        "v2Models": v2Options,
        "locations": locationOptions,
        "languages": languageOptions,
    }