From 21ab7c7ac734abcafeadbb9803b40e68395501f3 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 11 Apr 2025 17:25:56 +0800 Subject: [PATCH] [fix] Optimize the llm model switching method and fix the instance exit problem --- api_server.py | 17 +++++------------ backend/llm_client_backend.py | 17 +++++++++-------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/api_server.py b/api_server.py index 220d15d..f110d43 100644 --- a/api_server.py +++ b/api_server.py @@ -58,7 +58,7 @@ async def auth_middleware(request: Request, call_next): class ModelDispatcher: def __init__(self): self.backends = {} - self.llm_models = [] + self.llm_models = set() self.lock = asyncio.Lock() async def get_backend(self, model_name): @@ -70,21 +70,14 @@ class ModelDispatcher: if model_config["type"] == "openai_proxy": self.backends[model_name] = OpenAIProxyBackend(model_config) elif model_config["type"] in ("llm", "vlm"): - logger.debug(f"self.llm_models: {self.llm_models}") - if self.llm_models and model_name not in self.llm_models: - for old_model in self.llm_models: - old_instance = self.backends.pop(old_model, None) + if model_name not in self.llm_models: + for old_model_name in list(self.llm_models): + old_instance = self.backends.pop(old_model_name, None) if old_instance: await old_instance.close() self.llm_models.clear() - count = model_config["pool_size"] - while len(self.llm_models) >= count: - oldest_model = self.llm_models.pop(0) - old_instance = self.backends.pop(oldest_model, None) - if old_instance: - await old_instance.close() self.backends[model_name] = LlmClientBackend(model_config) - self.llm_models.append(model_name) + self.llm_models.add(model_name) elif model_config["type"] == "vision_model": self.backends[model_name] = VisionModelBackend(model_config) elif model_config["type"] == "tts": diff --git a/backend/llm_client_backend.py b/backend/llm_client_backend.py index 183aa8c..cf18ca2 100644 --- a/backend/llm_client_backend.py +++ b/backend/llm_client_backend.py @@ -123,14 +123,15 @@ class LlmClientBackend(BaseModelBackend): self.logger.debug(f"Returned client to pool | ID:{id(client)}") async def close(self): - async with self._pool_lock: - for task in self._active_tasks: - task.cancel() - for client in self._client_pool: - client.exit() - self._client_pool.clear() - self._active_clients.clear() - self._inference_executor.shutdown(wait=True) + for task in self._active_tasks: + task.cancel() + for client in self._active_clients.values(): + client.exit() + for client in self._client_pool: + client.exit() + self._client_pool.clear() + self._active_clients.clear() + self._inference_executor.shutdown(wait=False) async def inference_stream(self, query: str, base64_images: list, request: ChatCompletionRequest): client = await self._get_client(request)