You've already forked ModuleLLM-OpenAI-Plugin
mirror of
https://github.com/m5stack/ModuleLLM-OpenAI-Plugin.git
synced 2026-05-20 11:37:26 -07:00
[feat] Add stop inference support
This commit is contained in:
+48
-65
@@ -11,11 +11,11 @@ from slowapi.util import get_remote_address
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from aiostream import stream
|
||||
from llm_client import LLMClient
|
||||
import aiohttp
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import weakref
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
@@ -166,6 +166,7 @@ class LlmClientBackend(BaseModelBackend):
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self.logger = logging.getLogger("api.client")
|
||||
self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
|
||||
self._active_tasks = weakref.WeakSet()
|
||||
|
||||
async def _parse_content(self, content: Union[str, List[ContentItem]]) -> str:
|
||||
if isinstance(content, list):
|
||||
@@ -177,35 +178,32 @@ class LlmClientBackend(BaseModelBackend):
|
||||
url = item.image_url.get("url", "")
|
||||
if url.startswith("data:image"):
|
||||
base64_data = url.split(",", 1)[1]
|
||||
text_parts.append(f"[图片Base64数据:{base64_data[:100]}...]")
|
||||
else:
|
||||
base64_str = await self.download_image(url)
|
||||
if base64_str:
|
||||
text_parts.append(f"[网络图片Base64数据:{base64_str[:100]}...]")
|
||||
pass
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
async def _get_client(self, request):
|
||||
try:
|
||||
await asyncio.wait_for(self._pool_lock.acquire(), timeout=5.0)
|
||||
# 尝试从池中获取可用连接
|
||||
|
||||
if self._client_pool:
|
||||
client = self._client_pool.pop()
|
||||
self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}")
|
||||
return client
|
||||
|
||||
# 检查是否达到最大连接数
|
||||
if len(self._active_clients) >= self.POOL_SIZE:
|
||||
raise RuntimeError("Connection pool exhausted")
|
||||
|
||||
# 创建新连接
|
||||
self.logger.debug("🆕 Creating new LLM client")
|
||||
client = LLMClient(
|
||||
host=self.config["host"],
|
||||
port=self.config["port"]
|
||||
)
|
||||
self._active_clients[id(client)] = client
|
||||
|
||||
|
||||
system_content = next(
|
||||
(m.content for m in request.messages if m.role == "system"),
|
||||
self.config.get("system_prompt", "You are a helpful assistant")
|
||||
@@ -238,14 +236,15 @@ class LlmClientBackend(BaseModelBackend):
|
||||
|
||||
async def _release_client(self, client):
    """Hand *client* back to the idle pool so a later request can reuse it.

    The pool is only touched while ``self._pool_lock`` is held, which is
    the same lock ``_get_client`` takes before popping from the pool.

    Args:
        client: The client object being returned; it is appended to
            ``self._client_pool`` unchanged.
    """
    async with self._pool_lock:
        # Put the connection back into the pool for later reuse.
        self._client_pool.append(client)
        # Log the hand-back exactly once (the emoji-prefixed duplicate of
        # this message was dropped).
        self.logger.debug(f"Returned client to pool | ID:{id(client)}")
|
||||
|
||||
async def inference_stream(self, query: str, request: ChatCompletionRequest):
|
||||
client = await self._get_client(request)
|
||||
task = asyncio.current_task()
|
||||
self._active_tasks.add(task)
|
||||
try:
|
||||
self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}")
|
||||
self.logger.debug(f"Starting inference | ClientID:{id(client)} Query length:{len(query)}")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
sync_gen = client.inference_stream(
|
||||
@@ -254,56 +253,34 @@ class LlmClientBackend(BaseModelBackend):
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
def get_next():
|
||||
try:
|
||||
return next(sync_gen)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
chunk = await loop.run_in_executor(
|
||||
self._inference_executor,
|
||||
get_next
|
||||
)
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self.logger.error(f"Inference error: {str(e)}")
|
||||
yield f"[ERROR: {str(e)}]"
|
||||
if task.cancelled():
|
||||
client.stop_inference()
|
||||
break
|
||||
finally:
|
||||
await self._release_client(client)
|
||||
|
||||
async def inference_jpeg(self, query: str, request: ChatCompletionRequest):
|
||||
client = await self._get_client(request)
|
||||
try:
|
||||
self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
sync_img = client.inference_stream(
|
||||
query,
|
||||
object_type="vlm.jpeg.stream.base64"
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
def get_next():
|
||||
try:
|
||||
return next(sync_img)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
chunk = await loop.run_in_executor(None, get_next)
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self.logger.error(f"Inference error: {str(e)}")
|
||||
yield f"[ERROR: {str(e)}]"
|
||||
|
||||
def get_next():
|
||||
try:
|
||||
return next(sync_gen)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
chunk = await loop.run_in_executor(
|
||||
self._inference_executor,
|
||||
get_next
|
||||
)
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
self.logger.warning("Inference task cancelled, stopping...")
|
||||
client.stop_inference()
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Inference error: {str(e)}")
|
||||
yield f"[ERROR: {str(e)}]"
|
||||
finally:
|
||||
self._active_tasks.discard(task)
|
||||
await self._release_client(client)
|
||||
self.logger.debug(f"Inference stopped | ClientID:{id(client)}")
|
||||
|
||||
def _truncate_history(self, messages: List[Message]) -> List[Message]:
|
||||
"""Truncate history to fit model context window"""
|
||||
@@ -342,6 +319,11 @@ class LlmClientBackend(BaseModelBackend):
|
||||
truncated_messages = self._truncate_history(request.messages)
|
||||
|
||||
query_lines = []
|
||||
isJpeg = False
|
||||
isUrl = False
|
||||
base64_data = ""
|
||||
base64_str = ""
|
||||
|
||||
for m in truncated_messages:
|
||||
if m.role == "system":
|
||||
continue
|
||||
@@ -355,13 +337,10 @@ class LlmClientBackend(BaseModelBackend):
|
||||
url = item.image_url.get("url", "")
|
||||
if url.startswith("data:image"):
|
||||
base64_data = url.split(",", 1)[1]
|
||||
text_parts.append(f"[图片Base64数据:{base64_data[:100]}...]") # 截断防止过长
|
||||
else:
|
||||
base64_str = await self.download_image(url)
|
||||
if base64_str:
|
||||
text_parts.append(f"[网络图片Base64数据:{base64_str[:100]}...]")
|
||||
else:
|
||||
text_parts.append("[图片下载失败]")
|
||||
pass
|
||||
combined_content = " ".join(text_parts)
|
||||
query_lines.append(f"{m.role}: {combined_content}")
|
||||
else:
|
||||
@@ -373,7 +352,7 @@ class LlmClientBackend(BaseModelBackend):
|
||||
f"Context truncated: Original {len(request.messages)} → Kept {len(truncated_messages)} "
|
||||
f"Total length:{len(query)} chars"
|
||||
)
|
||||
|
||||
|
||||
if request.stream:
|
||||
async def chunk_generator():
|
||||
try:
|
||||
@@ -480,10 +459,14 @@ async def chat_completions(request: Request, body: ChatCompletionRequest):
|
||||
json_chunk = json.dumps(chunk_dict, ensure_ascii=False)
|
||||
print(f"Sending chunk: {json_chunk}")
|
||||
yield f"data: {json_chunk}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Stream interrupted: {str(e)}")
|
||||
yield f"data: {{'error': 'Stream interrupted'}}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("客户端提前断开连接,正在终止推理...")
|
||||
if backend and isinstance(backend, LlmClientBackend):
|
||||
for task in backend._active_tasks:
|
||||
task.cancel()
|
||||
raise
|
||||
finally:
|
||||
logger.debug("流连接已关闭")
|
||||
|
||||
return StreamingResponse(
|
||||
format_stream(),
|
||||
|
||||
+11
-2
@@ -78,6 +78,14 @@ class LLMClient:
|
||||
self.work_id = response["work_id"]
|
||||
break
|
||||
|
||||
def stop_inference(self) -> str:
    """Send a ``"pause"`` request to halt the inference in progress.

    The request is issued through ``self._send_request`` with object type
    ``"llm.utf-8"`` and an empty payload. No response is awaited here.

    Returns:
        The value returned by ``_send_request`` for the pause request.
        NOTE(review): annotated ``str`` to match ``send_jpeg``, which
        returns the same kind of value; confirm against ``_send_request``.
    """
    return self._send_request("pause", "llm.utf-8", {})
|
||||
|
||||
def send_jpeg(self, query: str, object_type: str = "vlm.jpeg.base64") -> str:
    """Submit *query* as an ``"inference"`` request and return at once.

    Args:
        query: Payload forwarded unchanged to ``_send_request``.
            NOTE(review): presumably base64-encoded JPEG data, going by
            the default object type; confirm with callers.
        object_type: Object type tag sent with the request.

    Returns:
        The value ``_send_request`` hands back for this request; no
        response is awaited here.
    """
    return self._send_request("inference", object_type, query)
|
||||
|
||||
def exit(self) -> dict:
|
||||
request_id = self._send_request("exit", "llm.utf-8", {})
|
||||
result = self._wait_response(request_id)
|
||||
@@ -103,7 +111,7 @@ class LLMClient:
|
||||
if __name__ == "__main__":
|
||||
with LLMClient(host='192.168.20.183') as client:
|
||||
setup_response = client.setup("llm.setup", {
|
||||
"model": "deepseek-r1-1.5B-ax630c",
|
||||
"model": "Qwen2.5-0.5B-w8a16",
|
||||
"response_format": "llm.utf-8.stream",
|
||||
"input": "llm.utf-8",
|
||||
"enoutput": True,
|
||||
@@ -112,8 +120,9 @@ if __name__ == "__main__":
|
||||
})
|
||||
print("Setup response:", setup_response)
|
||||
|
||||
for chunk in client.inference_stream("What's your name?"):
|
||||
for chunk in client.inference_stream("给我讲一个故事"):
|
||||
print("Received chunk:", chunk)
|
||||
client.stop_inference()
|
||||
|
||||
exit_response = client.exit()
|
||||
print("Exit response:", exit_response)
|
||||
Reference in New Issue
Block a user