You've already forked ModuleLLM-OpenAI-Plugin
mirror of
https://github.com/m5stack/ModuleLLM-OpenAI-Plugin.git
synced 2026-05-20 11:37:26 -07:00
[perf] Add audio segmentation and transcoding
This commit is contained in:
+35
-7
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
import os
|
||||
import uuid
|
||||
import yaml
|
||||
@@ -6,6 +7,7 @@ import time
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from pydub import AudioSegment
|
||||
from fastapi import FastAPI, Request, HTTPException, File, Form, UploadFile
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from backend import (
|
||||
@@ -281,16 +283,42 @@ async def create_transcription(
|
||||
|
||||
try:
|
||||
audio_data = await file.read()
|
||||
transcription = await backend.create_transcription(
|
||||
audio_data,
|
||||
language=language,
|
||||
prompt=prompt
|
||||
)
|
||||
audio = AudioSegment.from_file(io.BytesIO(audio_data), format=file.filename.split('.')[-1])
|
||||
|
||||
target_sample_rate = 16000
|
||||
target_channels = 1
|
||||
target_sample_width = 2
|
||||
|
||||
if audio.frame_rate != target_sample_rate or audio.channels != target_channels or audio.sample_width != target_sample_width:
|
||||
audio = audio.set_frame_rate(target_sample_rate).set_channels(target_channels).set_sample_width(target_sample_width)
|
||||
|
||||
segment_duration_ms = 30 * 1000
|
||||
segments = [audio[i:i + segment_duration_ms] for i in range(0, len(audio), segment_duration_ms)]
|
||||
|
||||
transcription_results = []
|
||||
for segment in segments:
|
||||
segment_data = io.BytesIO()
|
||||
segment.export(segment_data, format="wav")
|
||||
segment_data.seek(0)
|
||||
|
||||
transcription = await backend.create_transcription(
|
||||
segment_data.read(),
|
||||
language=language,
|
||||
prompt=prompt
|
||||
)
|
||||
transcription_results.append(transcription)
|
||||
|
||||
full_transcription = " ".join(transcription_results)
|
||||
|
||||
return JSONResponse(content={
|
||||
"text": transcription,
|
||||
"text": full_transcription,
|
||||
"task": "transcribe",
|
||||
"language": language,
|
||||
"duration": 0
|
||||
"duration": len(audio) / 1000.0,
|
||||
"segments": len(segments),
|
||||
"sample_rate": target_sample_rate,
|
||||
"channels": target_channels,
|
||||
"bit_depth": target_sample_width * 8
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {str(e)}")
|
||||
|
||||
@@ -37,6 +37,9 @@ class ASRClientBackend(BaseModelBackend):
|
||||
self.logger.debug(f"Reusing client from pool | ID:{id(client)}")
|
||||
return client
|
||||
|
||||
if len(self._active_clients) < self.POOL_SIZE:
|
||||
break
|
||||
|
||||
for task in self._active_tasks:
|
||||
task.cancel()
|
||||
|
||||
@@ -44,29 +47,29 @@ class ASRClientBackend(BaseModelBackend):
|
||||
await asyncio.sleep(retry_interval)
|
||||
await asyncio.wait_for(self._pool_lock.acquire(), timeout=timeout - (time.time() - start_time))
|
||||
|
||||
# if "memory_required" in self.config:
|
||||
# await self.memory_checker.check_memory(self.config["memory_required"])
|
||||
self.logger.debug("Creating new LLM client")
|
||||
client = ASRClient(
|
||||
host=self.config["host"],
|
||||
port=self.config["port"]
|
||||
)
|
||||
self._active_clients[id(client)] = client
|
||||
if "memory_required" in self.config:
|
||||
await self.memory_checker.check_memory(self.config["memory_required"])
|
||||
self.logger.debug("Creating new LLM client")
|
||||
client = ASRClient(
|
||||
host=self.config["host"],
|
||||
port=self.config["port"]
|
||||
)
|
||||
self._active_clients[id(client)] = client
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
client.setup,
|
||||
"whisper.setup",
|
||||
{
|
||||
"model": self.config["model_name"],
|
||||
"response_format": "asr.utf-8",
|
||||
"input": "whisper.base64.stream",
|
||||
"language": "zh",
|
||||
"enoutput": True
|
||||
}
|
||||
)
|
||||
return client
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
client.setup,
|
||||
"whisper.setup",
|
||||
{
|
||||
"model": self.config["model_name"],
|
||||
"response_format": "asr.utf-8",
|
||||
"input": "whisper.base64.stream",
|
||||
"language": "zh",
|
||||
"enoutput": True
|
||||
}
|
||||
)
|
||||
return client
|
||||
except asyncio.TimeoutError:
|
||||
raise RuntimeError("Server busy, please try again later.")
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user