diff --git a/api_server.py b/api_server.py index 077269d..d99a3e1 100644 --- a/api_server.py +++ b/api_server.py @@ -11,11 +11,11 @@ from slowapi.util import get_remote_address import time import json import asyncio -from aiostream import stream from llm_client import LLMClient import aiohttp import base64 from concurrent.futures import ThreadPoolExecutor +import weakref logging.basicConfig( level=logging.DEBUG, @@ -166,6 +166,7 @@ class LlmClientBackend(BaseModelBackend): self._pool_lock = asyncio.Lock() self.logger = logging.getLogger("api.client") self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE) + self._active_tasks = weakref.WeakSet() async def _parse_content(self, content: Union[str, List[ContentItem]]) -> str: if isinstance(content, list): @@ -177,35 +178,32 @@ class LlmClientBackend(BaseModelBackend): url = item.image_url.get("url", "") if url.startswith("data:image"): base64_data = url.split(",", 1)[1] - text_parts.append(f"[图片Base64数据:{base64_data[:100]}...]") else: base64_str = await self.download_image(url) if base64_str: - text_parts.append(f"[网络图片Base64数据:{base64_str[:100]}...]") + pass return " ".join(text_parts) return str(content) async def _get_client(self, request): try: await asyncio.wait_for(self._pool_lock.acquire(), timeout=5.0) - # 尝试从池中获取可用连接 + if self._client_pool: client = self._client_pool.pop() self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}") return client - # 检查是否达到最大连接数 if len(self._active_clients) >= self.POOL_SIZE: raise RuntimeError("Connection pool exhausted") - # 创建新连接 self.logger.debug("🆕 Creating new LLM client") client = LLMClient( host=self.config["host"], port=self.config["port"] ) self._active_clients[id(client)] = client - + system_content = next( (m.content for m in request.messages if m.role == "system"), self.config.get("system_prompt", "You are a helpful assistant") @@ -238,14 +236,15 @@ class LlmClientBackend(BaseModelBackend): async def _release_client(self, client): async with self._pool_lock: - # 将连接放回池中供后续使用 self._client_pool.append(client) - self.logger.debug(f"🔙 Returned client to pool | ID:{id(client)}") + self.logger.debug(f"Returned client to pool | ID:{id(client)}") async def inference_stream(self, query: str, request: ChatCompletionRequest): client = await self._get_client(request) + task = asyncio.current_task() + self._active_tasks.add(task) try: - self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}") + self.logger.debug(f"Starting inference | ClientID:{id(client)} Query length:{len(query)}") loop = asyncio.get_event_loop() sync_gen = client.inference_stream( @@ -254,56 +253,34 @@ class LlmClientBackend(BaseModelBackend): ) while True: - try: - def get_next(): - try: - return next(sync_gen) - except StopIteration: - return None - - chunk = await loop.run_in_executor( - self._inference_executor, - get_next - ) - if chunk is None: - break - yield chunk - except Exception as e: - self.logger.error(f"Inference error: {str(e)}") - yield f"[ERROR: {str(e)}]" + if task.cancelled(): + client.stop_inference() break - finally: - await self._release_client(client) - - async def inference_jpeg(self, query: str, request: ChatCompletionRequest): - client = await self._get_client(request) - try: - self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}") - - loop = asyncio.get_event_loop() - sync_img = client.inference_stream( - query, - object_type="vlm.jpeg.stream.base64" - ) - - while True: - try: - def get_next(): - try: - return next(sync_img) - except StopIteration: - return None - - chunk = await loop.run_in_executor(None, get_next) - if chunk is None: - break - yield chunk - except Exception as e: - self.logger.error(f"Inference error: {str(e)}") - yield f"[ERROR: {str(e)}]" + + def get_next(): + try: + return next(sync_gen) + except StopIteration: + return None + + chunk = await loop.run_in_executor( + self._inference_executor, + get_next + ) + if chunk is None: break + yield chunk + except asyncio.CancelledError: + self.logger.warning("Inference task cancelled, stopping...") + client.stop_inference() + raise + except Exception as e: + self.logger.error(f"Inference error: {str(e)}") + yield f"[ERROR: {str(e)}]" finally: + self._active_tasks.discard(task) await self._release_client(client) + self.logger.debug(f"Inference stopped | ClientID:{id(client)}") def _truncate_history(self, messages: List[Message]) -> List[Message]: """Truncate history to fit model context window""" @@ -342,6 +319,11 @@ class LlmClientBackend(BaseModelBackend): truncated_messages = self._truncate_history(request.messages) query_lines = [] + isJpeg = False + isUrl = False + base64_data = "" + base64_str = "" + for m in truncated_messages: if m.role == "system": continue @@ -355,13 +337,10 @@ class LlmClientBackend(BaseModelBackend): url = item.image_url.get("url", "") if url.startswith("data:image"): base64_data = url.split(",", 1)[1] - text_parts.append(f"[图片Base64数据:{base64_data[:100]}...]") # 截断防止过长 else: base64_str = await self.download_image(url) if base64_str: - text_parts.append(f"[网络图片Base64数据:{base64_str[:100]}...]") - else: - text_parts.append("[图片下载失败]") + pass combined_content = " ".join(text_parts) query_lines.append(f"{m.role}: {combined_content}") else: @@ -373,7 +352,7 @@ class LlmClientBackend(BaseModelBackend): f"Context truncated: Original {len(request.messages)} → Kept {len(truncated_messages)} " f"Total length:{len(query)} chars" ) - + if request.stream: async def chunk_generator(): try: @@ -480,10 +459,14 @@ async def chat_completions(request: Request, body: ChatCompletionRequest): json_chunk = json.dumps(chunk_dict, ensure_ascii=False) print(f"Sending chunk: {json_chunk}") yield f"data: {json_chunk}\n\n" - except Exception as e: - logger.error(f"Stream interrupted: {str(e)}") - yield f"data: {{'error': 'Stream interrupted'}}\n\n" - yield "data: [DONE]\n\n" + except asyncio.CancelledError: + logger.warning("客户端提前断开连接,正在终止推理...") + if backend and isinstance(backend, LlmClientBackend): + for task in backend._active_tasks: + task.cancel() + raise + finally: + logger.debug("流连接已关闭") return StreamingResponse( format_stream(), diff --git a/llm_client.py b/llm_client.py index 6c48fde..b5a6f54 100644 --- a/llm_client.py +++ b/llm_client.py @@ -78,6 +78,14 @@ class LLMClient: self.work_id = response["work_id"] break + def stop_inference(self) -> dict: + request_id = self._send_request("pause", "llm.utf-8", {}) + return request_id + + def send_jpeg(self, query: str, object_type: str = "vlm.jpeg.base64") -> str: + request_id = self._send_request("inference", object_type, query) + return request_id + def exit(self) -> dict: request_id = self._send_request("exit", "llm.utf-8", {}) result = self._wait_response(request_id) @@ -103,7 +111,7 @@ class LLMClient: if __name__ == "__main__": with LLMClient(host='192.168.20.183') as client: setup_response = client.setup("llm.setup", { - "model": "deepseek-r1-1.5B-ax630c", + "model": "Qwen2.5-0.5B-w8a16", "response_format": "llm.utf-8.stream", "input": "llm.utf-8", "enoutput": True, @@ -112,8 +120,9 @@ if __name__ == "__main__": }) print("Setup response:", setup_response) - for chunk in client.inference_stream("What's your name?"): + for chunk in client.inference_stream("给我讲一个故事"): print("Received chunk:", chunk) + client.stop_inference() exit_response = client.exit() print("Exit response:", exit_response) \ No newline at end of file