[perf] Add audio segmentation and transcoding

This commit is contained in:
LittleMouse
2025-04-15 16:36:40 +08:00
parent 0bf4767ff0
commit c0bcfd4d4b
2 changed files with 60 additions and 29 deletions
+35 -7
View File
@@ -1,3 +1,4 @@
import io
import os
import uuid
import yaml
@@ -6,6 +7,7 @@ import time
import json
import asyncio
from pydub import AudioSegment
from fastapi import FastAPI, Request, HTTPException, File, Form, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from backend import (
@@ -281,16 +283,42 @@ async def create_transcription(
try:
audio_data = await file.read()
transcription = await backend.create_transcription(
audio_data,
language=language,
prompt=prompt
)
audio = AudioSegment.from_file(io.BytesIO(audio_data), format=file.filename.split('.')[-1])
target_sample_rate = 16000
target_channels = 1
target_sample_width = 2
if audio.frame_rate != target_sample_rate or audio.channels != target_channels or audio.sample_width != target_sample_width:
audio = audio.set_frame_rate(target_sample_rate).set_channels(target_channels).set_sample_width(target_sample_width)
segment_duration_ms = 30 * 1000
segments = [audio[i:i + segment_duration_ms] for i in range(0, len(audio), segment_duration_ms)]
transcription_results = []
for segment in segments:
segment_data = io.BytesIO()
segment.export(segment_data, format="wav")
segment_data.seek(0)
transcription = await backend.create_transcription(
segment_data.read(),
language=language,
prompt=prompt
)
transcription_results.append(transcription)
full_transcription = " ".join(transcription_results)
return JSONResponse(content={
"text": transcription,
"text": full_transcription,
"task": "transcribe",
"language": language,
"duration": 0
"duration": len(audio) / 1000.0,
"segments": len(segments),
"sample_rate": target_sample_rate,
"channels": target_channels,
"bit_depth": target_sample_width * 8
})
except Exception as e:
logger.error(f"Transcription error: {str(e)}")
+25 -22
View File
@@ -37,6 +37,9 @@ class ASRClientBackend(BaseModelBackend):
self.logger.debug(f"Reusing client from pool | ID:{id(client)}")
return client
if len(self._active_clients) < self.POOL_SIZE:
break
for task in self._active_tasks:
task.cancel()
@@ -44,29 +47,29 @@ class ASRClientBackend(BaseModelBackend):
await asyncio.sleep(retry_interval)
await asyncio.wait_for(self._pool_lock.acquire(), timeout=timeout - (time.time() - start_time))
# if "memory_required" in self.config:
# await self.memory_checker.check_memory(self.config["memory_required"])
self.logger.debug("Creating new LLM client")
client = ASRClient(
host=self.config["host"],
port=self.config["port"]
)
self._active_clients[id(client)] = client
if "memory_required" in self.config:
await self.memory_checker.check_memory(self.config["memory_required"])
self.logger.debug("Creating new LLM client")
client = ASRClient(
host=self.config["host"],
port=self.config["port"]
)
self._active_clients[id(client)] = client
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
client.setup,
"whisper.setup",
{
"model": self.config["model_name"],
"response_format": "asr.utf-8",
"input": "whisper.base64.stream",
"language": "zh",
"enoutput": True
}
)
return client
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
client.setup,
"whisper.setup",
{
"model": self.config["model_name"],
"response_format": "asr.utf-8",
"input": "whisper.base64.stream",
"language": "zh",
"enoutput": True
}
)
return client
except asyncio.TimeoutError:
raise RuntimeError("Server busy, please try again later.")
finally: