217 lines
7.3 KiB
Python
217 lines
7.3 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""STT Benchmark route — compare Speech-to-Text v1 (latest_long) vs v2 (Chirp 2).
|
|
|
|
Backend for the admin-only page (SysAdmin or PlatformAdmin) used to evaluate STT model quality and latency.
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
import logging
|
|
from typing import Any, Dict
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
|
|
from modules.auth import limiter, getCurrentUser
|
|
from modules.datamodels.datamodelUam import User
|
|
from modules.shared.configuration import APP_CONFIG
|
|
|
|
# Logger named after this module so benchmark log lines are attributable.
logger = logging.getLogger(__name__)

# Router for all STT benchmark endpoints. The 401/403 responses are declared
# for the OpenAPI schema; the routes below guard access via _requireSysAdmin.
router = APIRouter(
    responses={
        401: {"description": "Unauthorized"},
        403: {"description": "Forbidden"},
    },
    tags=["Admin STT Benchmark"],
    prefix="/api/admin/stt-benchmark",
)
|
|
|
|
|
|
def _requireSysAdmin(currentUser: User = Depends(getCurrentUser)) -> User:
    """FastAPI dependency that admits only system or platform administrators.

    Args:
        currentUser: The authenticated user resolved by ``getCurrentUser``.

    Returns:
        The same user object, so routes can depend on this guard instead of
        ``getCurrentUser`` directly.

    Raises:
        HTTPException: 403 if neither ``isSysAdmin`` nor ``isPlatformAdmin``
            is truthy on the user.
    """
    # getattr with a False default tolerates user objects lacking either flag.
    hasAdminRole = getattr(currentUser, "isSysAdmin", False) or getattr(currentUser, "isPlatformAdmin", False)
    if hasAdminRole:
        return currentUser
    raise HTTPException(status_code=403, detail="SysAdmin required")
|
|
|
|
|
|
def _getCredentials():
    """Build Google service-account credentials from the app configuration.

    The config value ``Connector_GoogleSpeech_API_KEY_SECRET`` is parsed as a
    JSON service-account document (despite the "API_KEY" in its name).

    Returns:
        ``google.oauth2.service_account.Credentials`` built from that JSON.

    Raises:
        HTTPException: 500 if the secret is missing, is still a ``YOUR_...``
            placeholder, is not valid JSON, is not a JSON object, or is
            rejected by google-auth as a service-account document.
    """
    apiKey = APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET")
    # Template placeholders ("YOUR_...") are treated the same as a missing value.
    if not apiKey or apiKey.startswith("YOUR_"):
        raise HTTPException(status_code=500, detail="Google Speech API key not configured")

    # Imported lazily, matching the other google imports in this module.
    from google.oauth2 import service_account

    try:
        credInfo = json.loads(apiKey)
    except ValueError:
        # The decoder error only reports a character offset into the secret,
        # which is no help to operators; replace it with a clear message.
        raise HTTPException(
            status_code=500,
            detail="Google Speech credentials are not valid JSON",
        ) from None
    if not isinstance(credInfo, dict):
        raise HTTPException(
            status_code=500,
            detail="Google Speech credentials must be a JSON object",
        )

    try:
        return service_account.Credentials.from_service_account_info(credInfo)
    except ValueError as err:
        # google-auth raises ValueError for missing fields (e.g. client_email,
        # private_key) or an unparsable private key.
        raise HTTPException(
            status_code=500,
            detail=f"Google Speech credentials are not a valid service account: {err}",
        ) from err
|
|
|
|
|
|
def _runV1(audioBytes: bytes, language: str, model: str) -> Dict[str, Any]:
    """Run a synchronous Speech-to-Text v1 recognition and time it.

    Args:
        audioBytes: Uploaded audio, sent inline as ``RecognitionAudio.content``.
        language: Value for ``language_code`` (e.g. ``"de-DE"``).
        model: v1 model name (e.g. ``"latest_long"``).

    Returns:
        Dict with ``api`` ("v1"), ``model``, ``latencyMs`` (wall-clock time of
        the ``recognize`` call only, excluding client setup), ``results`` (one
        entry per alternative of every result, holding transcript, confidence
        rounded to 4 places and word count) and ``resultCount`` (number of
        results, not alternatives).

    Raises:
        HTTPException: propagated from ``_getCredentials``.
        Any exception raised by the Google client (e.g. rejected audio).
    """
    from google.cloud import speech

    credentials = _getCredentials()

    config = speech.RecognitionConfig(
        # NOTE(review): per the v1 API docs, encoding may only be omitted for
        # WAV/FLAC (read from the file header); confirm uploads are limited
        # to those formats.
        encoding=speech.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
        language_code=language,
        model=model,
        enable_automatic_punctuation=True,
        enable_word_time_offsets=True,
        enable_word_confidence=True,
        max_alternatives=3,
        use_enhanced=True,
    )
    audio = speech.RecognitionAudio(content=audioBytes)

    # Use the client as a context manager so its transport (gRPC channel) is
    # closed after this one-off call instead of leaking one per request.
    with speech.SpeechClient(credentials=credentials) as client:
        t0 = time.perf_counter()
        response = client.recognize(config=config, audio=audio)
        elapsed = time.perf_counter() - t0

    results = [
        {
            "transcript": alt.transcript,
            "confidence": round(alt.confidence, 4),
            "words": len(alt.words),
        }
        for result in response.results
        for alt in result.alternatives
    ]

    return {
        "api": "v1",
        "model": model,
        "latencyMs": round(elapsed * 1000, 1),
        "results": results,
        "resultCount": len(response.results),
    }
|
|
|
|
|
|
def _runV2(audioBytes: bytes, language: str, model: str, location: str) -> Dict[str, Any]:
    """Run a synchronous Speech-to-Text v2 recognition (e.g. Chirp 2) and time it.

    Uses the implicit recognizer ``projects/<project>/locations/<location>/recognizers/_``
    against the regional endpoint ``<location>-speech.googleapis.com``.

    Args:
        audioBytes: Uploaded audio, sent inline as request ``content``.
        language: Single entry for ``language_codes`` (e.g. ``"de-DE"``).
        model: v2 model name (e.g. ``"chirp_2"``).
        location: Region id (e.g. ``"europe-west4"``); must consist only of
            lowercase letters, digits and hyphens.

    Returns:
        Dict with ``api`` ("v2"), ``model``, ``location``, ``latencyMs``
        (wall-clock time of the ``recognize`` call only), ``results`` (one
        entry per alternative of every result) and ``resultCount`` (number of
        results).

    Raises:
        ValueError: if ``location`` is malformed or the credentials carry no
            project id.
        HTTPException: propagated from ``_getCredentials``.
        Any exception raised by the Google client.
    """
    import re

    from google.cloud.speech_v2 import SpeechClient
    from google.cloud.speech_v2.types import cloud_speech

    # ``location`` comes straight from the request form and is interpolated
    # into the API hostname; restricting it to a plain region id keeps the
    # endpoint on *-speech.googleapis.com so credentials never go elsewhere.
    if not re.fullmatch(r"[a-z0-9-]+", location or ""):
        raise ValueError(f"Invalid v2 location: {location!r}")

    credentials = _getCredentials()
    # The service-account credentials already expose the project id parsed
    # from the secret; no need to decode the JSON a second time.
    projectId = credentials.project_id
    if not projectId:
        raise ValueError("Google Speech service account has no project_id")

    config = cloud_speech.RecognitionConfig(
        auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
        language_codes=[language],
        model=model,
        features=cloud_speech.RecognitionFeatures(
            enable_automatic_punctuation=True,
            enable_word_time_offsets=True,
            enable_word_confidence=True,
        ),
    )

    request = cloud_speech.RecognizeRequest(
        # "_" selects the implicit recognizer, so no recognizer resource has
        # to be created beforehand.
        recognizer=f"projects/{projectId}/locations/{location}/recognizers/_",
        config=config,
        content=audioBytes,
    )

    # Context manager closes the client's transport after this one-off call.
    with SpeechClient(
        credentials=credentials,
        client_options={"api_endpoint": f"{location}-speech.googleapis.com"},
    ) as client:
        t0 = time.perf_counter()
        response = client.recognize(request=request)
        elapsed = time.perf_counter() - t0

    results = [
        {
            "transcript": alt.transcript,
            "confidence": round(alt.confidence, 4),
            "words": len(alt.words),
        }
        for result in response.results
        for alt in result.alternatives
    ]

    return {
        "api": "v2",
        "model": model,
        "location": location,
        "latencyMs": round(elapsed * 1000, 1),
        "results": results,
        "resultCount": len(response.results),
    }
|
|
|
|
|
|
@router.post("/run")
@limiter.limit("10/minute")
async def runBenchmark(
    request: Request,
    file: UploadFile = File(...),
    language: str = Form(default="de-DE"),
    v1Model: str = Form(default="latest_long"),
    v2Model: str = Form(default="chirp_2"),
    v2Location: str = Form(default="europe-west4"),
    currentUser: User = Depends(_requireSysAdmin),
) -> Dict[str, Any]:
    """Upload audio and compare v1 vs v2 STT results.

    Both APIs are run one after the other on the same audio. A failure in one
    API is reported in its slot and does not prevent the other from running.

    Returns:
        Dict with ``filename``, ``fileSizeBytes``, ``language``, and ``v1`` /
        ``v2`` each holding either the dict from ``_runV1`` / ``_runV2`` or
        ``{"error": <message>}``.

    Raises:
        HTTPException: 400 if the upload is larger than 10 MB or smaller than
            100 bytes.
    """
    from fastapi.concurrency import run_in_threadpool

    maxBytes = 10 * 1024 * 1024
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without buffering an arbitrarily large file in memory.
    audioBytes = await file.read(maxBytes + 1)
    if len(audioBytes) > maxBytes:
        raise HTTPException(status_code=400, detail="Audio file too large (max 10 MB)")
    if len(audioBytes) < 100:
        raise HTTPException(status_code=400, detail="Audio file too small")

    logger.info("STT benchmark: %s, %d bytes, language=%s, v1=%s, v2=%s@%s",
                file.filename, len(audioBytes), language, v1Model, v2Model, v2Location)

    # The Google clients block, so they run in the threadpool to keep the event
    # loop responsive. The calls stay sequential so the two latency
    # measurements do not overlap.
    v1Result = None
    v1Error = None
    try:
        v1Result = await run_in_threadpool(_runV1, audioBytes, language, v1Model)
    except Exception as e:
        # Best-effort per API: record the error and continue with v2.
        v1Error = str(e)
        logger.warning("STT v1 benchmark failed: %s", e, exc_info=True)

    v2Result = None
    v2Error = None
    try:
        v2Result = await run_in_threadpool(_runV2, audioBytes, language, v2Model, v2Location)
    except Exception as e:
        v2Error = str(e)
        logger.warning("STT v2 benchmark failed: %s", e, exc_info=True)

    return {
        "filename": file.filename,
        "fileSizeBytes": len(audioBytes),
        "language": language,
        "v1": v1Result if v1Result is not None else {"error": v1Error},
        "v2": v2Result if v2Result is not None else {"error": v2Error},
    }
|
|
|
|
|
|
@router.get("/models")
@limiter.limit("30/minute")
async def getAvailableModels(
    request: Request,
    currentUser: User = Depends(_requireSysAdmin),
) -> Dict[str, Any]:
    """Return available STT models for the benchmark UI.

    Every option is a ``{"value": ..., "label": ...}`` dict. The first entry of
    each v1/v2/location/language list matches the corresponding form default
    of ``runBenchmark``.
    """

    def asOptions(pairs):
        # Turn (value, label) tuples into the option dicts the UI consumes.
        return [{"value": value, "label": label} for value, label in pairs]

    v1Pairs = [
        ("latest_long", "latest_long (default)"),
        ("latest_short", "latest_short"),
        ("phone_call", "phone_call"),
        ("video", "video"),
        ("command_and_search", "command_and_search"),
    ]
    v2Pairs = [
        ("chirp_2", "Chirp 2 (recommended)"),
        ("chirp", "Chirp (original)"),
        ("long", "long"),
        ("short", "short"),
    ]
    locationPairs = [
        ("europe-west4", "Europe West (NL)"),
        ("us-central1", "US Central"),
        ("asia-southeast1", "Asia Southeast"),
    ]
    languagePairs = [
        ("de-DE", "Deutsch (DE)"),
        ("de-CH", "Deutsch (CH)"),
        ("en-US", "English (US)"),
        ("en-GB", "English (GB)"),
        ("fr-FR", "Francais (FR)"),
        ("it-IT", "Italiano (IT)"),
    ]

    return {
        "v1Models": asOptions(v1Pairs),
        "v2Models": asOptions(v2Pairs),
        "locations": asOptions(locationPairs),
        "languages": asOptions(languagePairs),
    }
|