[feat] Add stop inference support

This commit is contained in:
LittleMouse
2025-02-18 17:13:25 +08:00
parent ad035c0cda
commit 7952c31d16
2 changed files with 59 additions and 67 deletions
+48 -65
View File
@@ -11,11 +11,11 @@ from slowapi.util import get_remote_address
import time
import json
import asyncio
from aiostream import stream
from llm_client import LLMClient
import aiohttp
import base64
from concurrent.futures import ThreadPoolExecutor
import weakref
logging.basicConfig(
level=logging.DEBUG,
@@ -166,6 +166,7 @@ class LlmClientBackend(BaseModelBackend):
self._pool_lock = asyncio.Lock()
self.logger = logging.getLogger("api.client")
self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
self._active_tasks = weakref.WeakSet()
async def _parse_content(self, content: Union[str, List[ContentItem]]) -> str:
if isinstance(content, list):
@@ -177,35 +178,32 @@ class LlmClientBackend(BaseModelBackend):
url = item.image_url.get("url", "")
if url.startswith("data:image"):
base64_data = url.split(",", 1)[1]
text_parts.append(f"[图片Base64数据:{base64_data[:100]}...]")
else:
base64_str = await self.download_image(url)
if base64_str:
text_parts.append(f"[网络图片Base64数据:{base64_str[:100]}...]")
pass
return " ".join(text_parts)
return str(content)
async def _get_client(self, request):
try:
await asyncio.wait_for(self._pool_lock.acquire(), timeout=5.0)
# 尝试从池中获取可用连接
if self._client_pool:
client = self._client_pool.pop()
self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}")
return client
# 检查是否达到最大连接数
if len(self._active_clients) >= self.POOL_SIZE:
raise RuntimeError("Connection pool exhausted")
# 创建新连接
self.logger.debug("🆕 Creating new LLM client")
client = LLMClient(
host=self.config["host"],
port=self.config["port"]
)
self._active_clients[id(client)] = client
system_content = next(
(m.content for m in request.messages if m.role == "system"),
self.config.get("system_prompt", "You are a helpful assistant")
@@ -238,14 +236,15 @@ class LlmClientBackend(BaseModelBackend):
async def _release_client(self, client):
async with self._pool_lock:
# 将连接放回池中供后续使用
self._client_pool.append(client)
self.logger.debug(f"🔙 Returned client to pool | ID:{id(client)}")
self.logger.debug(f"Returned client to pool | ID:{id(client)}")
async def inference_stream(self, query: str, request: ChatCompletionRequest):
client = await self._get_client(request)
task = asyncio.current_task()
self._active_tasks.add(task)
try:
self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}")
self.logger.debug(f"Starting inference | ClientID:{id(client)} Query length:{len(query)}")
loop = asyncio.get_event_loop()
sync_gen = client.inference_stream(
@@ -254,56 +253,34 @@ class LlmClientBackend(BaseModelBackend):
)
while True:
try:
def get_next():
try:
return next(sync_gen)
except StopIteration:
return None
chunk = await loop.run_in_executor(
self._inference_executor,
get_next
)
if chunk is None:
break
yield chunk
except Exception as e:
self.logger.error(f"Inference error: {str(e)}")
yield f"[ERROR: {str(e)}]"
if task.cancelled():
client.stop_inference()
break
finally:
await self._release_client(client)
async def inference_jpeg(self, query: str, request: ChatCompletionRequest):
client = await self._get_client(request)
try:
self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}")
loop = asyncio.get_event_loop()
sync_img = client.inference_stream(
query,
object_type="vlm.jpeg.stream.base64"
)
while True:
try:
def get_next():
try:
return next(sync_img)
except StopIteration:
return None
chunk = await loop.run_in_executor(None, get_next)
if chunk is None:
break
yield chunk
except Exception as e:
self.logger.error(f"Inference error: {str(e)}")
yield f"[ERROR: {str(e)}]"
def get_next():
try:
return next(sync_gen)
except StopIteration:
return None
chunk = await loop.run_in_executor(
self._inference_executor,
get_next
)
if chunk is None:
break
yield chunk
except asyncio.CancelledError:
self.logger.warning("Inference task cancelled, stopping...")
client.stop_inference()
raise
except Exception as e:
self.logger.error(f"Inference error: {str(e)}")
yield f"[ERROR: {str(e)}]"
finally:
self._active_tasks.discard(task)
await self._release_client(client)
self.logger.debug(f"Inference stopped | ClientID:{id(client)}")
def _truncate_history(self, messages: List[Message]) -> List[Message]:
"""Truncate history to fit model context window"""
@@ -342,6 +319,11 @@ class LlmClientBackend(BaseModelBackend):
truncated_messages = self._truncate_history(request.messages)
query_lines = []
isJpeg = False
isUrl = False
base64_data = ""
base64_str = ""
for m in truncated_messages:
if m.role == "system":
continue
@@ -355,13 +337,10 @@ class LlmClientBackend(BaseModelBackend):
url = item.image_url.get("url", "")
if url.startswith("data:image"):
base64_data = url.split(",", 1)[1]
text_parts.append(f"[图片Base64数据:{base64_data[:100]}...]") # 截断防止过长
else:
base64_str = await self.download_image(url)
if base64_str:
text_parts.append(f"[网络图片Base64数据:{base64_str[:100]}...]")
else:
text_parts.append("[图片下载失败]")
pass
combined_content = " ".join(text_parts)
query_lines.append(f"{m.role}: {combined_content}")
else:
@@ -373,7 +352,7 @@ class LlmClientBackend(BaseModelBackend):
f"Context truncated: Original {len(request.messages)} → Kept {len(truncated_messages)} "
f"Total length:{len(query)} chars"
)
if request.stream:
async def chunk_generator():
try:
@@ -480,10 +459,14 @@ async def chat_completions(request: Request, body: ChatCompletionRequest):
json_chunk = json.dumps(chunk_dict, ensure_ascii=False)
print(f"Sending chunk: {json_chunk}")
yield f"data: {json_chunk}\n\n"
except Exception as e:
logger.error(f"Stream interrupted: {str(e)}")
yield f"data: {{'error': 'Stream interrupted'}}\n\n"
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
logger.warning("客户端提前断开连接,正在终止推理...")
if backend and isinstance(backend, LlmClientBackend):
for task in backend._active_tasks:
task.cancel()
raise
finally:
logger.debug("流连接已关闭")
return StreamingResponse(
format_stream(),
+11 -2
View File
@@ -78,6 +78,14 @@ class LLMClient:
self.work_id = response["work_id"]
break
def stop_inference(self) -> dict:
request_id = self._send_request("pause", "llm.utf-8", {})
return request_id
def send_jpeg(self, query: str, object_type: str = "vlm.jpeg.base64") -> str:
request_id = self._send_request("inference", object_type, query)
return request_id
def exit(self) -> dict:
request_id = self._send_request("exit", "llm.utf-8", {})
result = self._wait_response(request_id)
@@ -103,7 +111,7 @@ class LLMClient:
if __name__ == "__main__":
with LLMClient(host='192.168.20.183') as client:
setup_response = client.setup("llm.setup", {
"model": "deepseek-r1-1.5B-ax630c",
"model": "Qwen2.5-0.5B-w8a16",
"response_format": "llm.utf-8.stream",
"input": "llm.utf-8",
"enoutput": True,
@@ -112,8 +120,9 @@ if __name__ == "__main__":
})
print("Setup response:", setup_response)
for chunk in client.inference_stream("What's your name?"):
for chunk in client.inference_stream("给我讲一个故事"):
print("Received chunk:", chunk)
client.stop_inference()
exit_response = client.exit()
print("Exit response:", exit_response)