diff --git a/api_server.py b/api_server.py index 6141bc4..4dcccc9 100644 --- a/api_server.py +++ b/api_server.py @@ -1,3 +1,4 @@ +import io import os import uuid import yaml @@ -6,6 +7,7 @@ import time import json import asyncio +from pydub import AudioSegment from fastapi import FastAPI, Request, HTTPException, File, Form, UploadFile from fastapi.responses import JSONResponse, StreamingResponse from backend import ( @@ -281,16 +283,42 @@ async def create_transcription( try: audio_data = await file.read() - transcription = await backend.create_transcription( - audio_data, - language=language, - prompt=prompt - ) + audio = AudioSegment.from_file(io.BytesIO(audio_data), format=file.filename.split('.')[-1]) + + target_sample_rate = 16000 + target_channels = 1 + target_sample_width = 2 + + if audio.frame_rate != target_sample_rate or audio.channels != target_channels or audio.sample_width != target_sample_width: + audio = audio.set_frame_rate(target_sample_rate).set_channels(target_channels).set_sample_width(target_sample_width) + + segment_duration_ms = 30 * 1000 + segments = [audio[i:i + segment_duration_ms] for i in range(0, len(audio), segment_duration_ms)] + + transcription_results = [] + for segment in segments: + segment_data = io.BytesIO() + segment.export(segment_data, format="wav") + segment_data.seek(0) + + transcription = await backend.create_transcription( + segment_data.read(), + language=language, + prompt=prompt + ) + transcription_results.append(transcription) + + full_transcription = " ".join(transcription_results) + return JSONResponse(content={ - "text": transcription, + "text": full_transcription, "task": "transcribe", "language": language, - "duration": 0 + "duration": len(audio) / 1000.0, + "segments": len(segments), + "sample_rate": target_sample_rate, + "channels": target_channels, + "bit_depth": target_sample_width * 8 }) except Exception as e: logger.error(f"Transcription error: {str(e)}") diff --git a/backend/asr_client_backend.py b/backend/asr_client_backend.py index 5169dde..f6499de 100644 --- a/backend/asr_client_backend.py +++ b/backend/asr_client_backend.py @@ -37,6 +37,9 @@ class ASRClientBackend(BaseModelBackend): self.logger.debug(f"Reusing client from pool | ID:{id(client)}") return client + if len(self._active_clients) < self.POOL_SIZE: + break + for task in self._active_tasks: task.cancel() @@ -44,29 +47,29 @@ class ASRClientBackend(BaseModelBackend): await asyncio.sleep(retry_interval) await asyncio.wait_for(self._pool_lock.acquire(), timeout=timeout - (time.time() - start_time)) - # if "memory_required" in self.config: - # await self.memory_checker.check_memory(self.config["memory_required"]) - self.logger.debug("Creating new LLM client") - client = ASRClient( - host=self.config["host"], - port=self.config["port"] - ) - self._active_clients[id(client)] = client + if "memory_required" in self.config: + await self.memory_checker.check_memory(self.config["memory_required"]) + self.logger.debug("Creating new LLM client") + client = ASRClient( + host=self.config["host"], + port=self.config["port"] + ) + self._active_clients[id(client)] = client - loop = asyncio.get_event_loop() - await loop.run_in_executor( - None, - client.setup, - "whisper.setup", - { - "model": self.config["model_name"], - "response_format": "asr.utf-8", - "input": "whisper.base64.stream", - "language": "zh", - "enoutput": True - } - ) - return client + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, + client.setup, + "whisper.setup", + { + "model": self.config["model_name"], + "response_format": "asr.utf-8", + "input": "whisper.base64.stream", + "language": "zh", + "enoutput": True + } + ) + return client except asyncio.TimeoutError: raise RuntimeError("Server busy, please try again later.") finally: