You've already forked ModuleLLM-OpenAI-Plugin
mirror of
https://github.com/m5stack/ModuleLLM-OpenAI-Plugin.git
synced 2026-05-20 11:37:26 -07:00
[perf] Optimize code directory structure
This commit is contained in:
+8
-369
@@ -3,19 +3,17 @@ import uuid
|
||||
import yaml
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Union
|
||||
import logging
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from llm_client import LLMClient
|
||||
import aiohttp
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import weakref
|
||||
from backend.test_backend import TestBackend
|
||||
from backend.openai_proxy_backend import OpenAIProxyBackend
|
||||
from backend.llm_client_backend import LlmClientBackend
|
||||
from backend.vision_model_backend import VisionModelBackend
|
||||
from backend.chat_schemas import ChatCompletionRequest, CompletionRequest, Message, ContentItem
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
@@ -31,36 +29,11 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
class Config:
    """Application configuration loaded from YAML at construction time.

    The parsed document is exposed unchanged as ``self.data``.
    """

    def __init__(self):
        # Only the post-commit location is opened; the listing also carried the
        # obsolete "config.yaml" line from before the directory restructure.
        # The encoding is explicit so non-ASCII prompts in the YAML file do not
        # depend on the platform default.
        with open("config/config.yaml", encoding="utf-8") as f:
            self.data = yaml.safe_load(f)
|
||||
|
||||
config = Config()
|
||||
|
||||
class ContentItem(BaseModel):
    """One part of a multi-part message body."""

    # Kind of part; the values handled in this file are "text" and "image_url".
    type: str
    # Text payload, read when ``type`` is "text".
    text: Optional[str] = None
    # Image reference, read when ``type`` is "image_url"; consumers look up
    # its "url" key.
    image_url: Optional[dict] = None
|
||||
|
||||
class Message(BaseModel):
    """A single chat message."""

    # Speaker of the message; "system" is the only value this file treats
    # specially.
    role: str
    # Either plain text or a list of typed content parts.
    content: Union[str, List[ContentItem]]
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
    """Request body for a chat completion."""

    # Name used to pick a backend and echoed back in responses.
    model: str
    # Conversation so far.
    messages: List[Message]
    # Sampling parameters; each falls back to the default shown when omitted.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 1000
    # When true the backend returns an async generator of chunks instead of
    # a single response.
    stream: Optional[bool] = False
|
||||
|
||||
class CompletionRequest(BaseModel):
    """Request body for a plain (prompt-based) completion."""

    # Name of the model the caller is addressing.
    model: str
    # Raw prompt text.
    prompt: str
    # Sampling parameters; each falls back to the default shown when omitted.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 1000
    # Streaming flag, off by default.
    stream: Optional[bool] = False
|
||||
|
||||
@app.middleware("http")
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
if request.url.path.startswith("/v1"):
|
||||
@@ -72,342 +45,6 @@ async def auth_middleware(request: Request, call_next):
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
class BaseModelBackend:
    """Common base for model backends: stores the per-model configuration."""

    def __init__(self, model_config):
        """Keep ``model_config`` on the instance as ``self.config``."""
        self.config = model_config

    async def generate(self, request: ChatCompletionRequest):
        """Produce a completion for ``request``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
|
||||
|
||||
class TestBackend(BaseModelBackend):
    """Stub backend that answers every request with a fixed emoji reply."""

    async def generate(self, request: ChatCompletionRequest):
        """Return a canned completion for ``request``.

        Non-streaming requests get one ``chat.completion`` dict with fixed
        usage numbers. Streaming requests get an async generator yielding one
        ``chat.completion.chunk`` dict per emoji; the first chunk carries the
        assistant role and the last one carries ``finish_reason="stop"``.
        """
        if not request.stream:
            reply = {
                "role": "assistant",
                "content": "🤣👉🏻🤡",
                "function_call": None,
                "tool_calls": None
            }
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "message": reply,
                    "finish_reason": "stop",
                    "index": 0
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30
                }
            }

        async def chunk_generator():
            pieces = ["🤣", "👉🏻", "🤡"]
            # The incoming messages are only dumped to stdout for inspection.
            dumped = [msg.model_dump() for msg in request.messages]
            print(f"messages:_____________{dumped}______________")
            final_position = len(pieces) - 1
            for position, piece in enumerate(pieces):
                delta = {
                    "content": piece,
                    "role": "assistant" if position == 0 else None,
                    "function_call": None,
                    "tool_calls": None
                }
                choice = {
                    "index": 0,
                    "delta": delta,
                    "logprobs": None,
                    "finish_reason": "stop" if position == final_position else None
                }
                yield {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [choice],
                    "service_tier": None,
                    "system_fingerprint": None,
                    "usage": None
                }

        return chunk_generator()
|
||||
|
||||
class OpenAIProxyBackend(BaseModelBackend):
    """Backend that forwards chat requests to an OpenAI-compatible endpoint.

    Reads ``api_key``, ``base_url`` and ``model`` from the model config.
    """

    async def generate(self, request: ChatCompletionRequest):
        """Forward ``request`` upstream and return the upstream result.

        Returns:
            The SDK response object for non-streaming requests, or an async
            generator re-yielding the upstream chunks when ``request.stream``
            is set.
        """
        from openai import AsyncOpenAI

        client = AsyncOpenAI(
            api_key=self.config["api_key"],
            base_url=self.config["base_url"]
        )

        # exclude_none drops the unused half of each content part
        # (e.g. ``image_url: null`` on a text part) instead of sending
        # explicit nulls upstream.
        upstream_messages = [
            m.model_dump(exclude_none=True) for m in request.messages
        ]

        response = await client.chat.completions.create(
            model=self.config["model"],
            messages=upstream_messages,
            temperature=request.temperature,
            # top_p is part of the request schema and was previously dropped.
            top_p=request.top_p,
            max_tokens=request.max_tokens,
            stream=request.stream
        )

        if not request.stream:
            return response

        async def async_wrapper():
            async for chunk in response:
                yield chunk

        return async_wrapper()
|
||||
|
||||
class LlmClientBackend(BaseModelBackend):
    """Backend that serves chat completions through pooled ``LLMClient`` objects.

    Clients are created on demand up to ``POOL_SIZE`` and recycled through
    ``_client_pool``. A client taken from the pool is returned as-is:
    ``setup()`` only runs when a client is first created, so a reused client
    keeps the parameters of the request that created it.
    """

    # Budget used by _truncate_history, in characters of message text.
    MAX_CONTEXT_LENGTH = 500
    # Caps both the number of clients created and the inference worker threads.
    POOL_SIZE = 2

    def __init__(self, model_config):
        """Create the empty pool, its lock, the logger and the worker executor."""
        super().__init__(model_config)
        self._client_pool = []
        self._active_clients = {}
        self._pool_lock = asyncio.Lock()
        self.logger = logging.getLogger("api.client")
        self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
        self._active_tasks = weakref.WeakSet()

    async def _parse_content(self, content: Union[str, List[ContentItem]], base64_images: list) -> str:
        """Flatten message content to text and collect its images.

        Args:
            content: Plain text or a list of content parts.
            base64_images: Output list; base64 payloads of every image part
                are appended to it in order.

        Returns:
            The text parts joined by single spaces (stripped).
        """
        if not isinstance(content, list):
            return str(content).strip()

        text_parts = []
        for item in content:
            if item.type == "text" and item.text:
                text_parts.append(item.text.strip())
            elif item.type == "image_url" and item.image_url:
                url = item.image_url.get("url", "")
                if url.startswith("data:image"):
                    # partition never raises; a data URL without a comma
                    # simply has no payload and is skipped.
                    _, _, payload = url.partition(",")
                    if payload:
                        base64_images.append(payload)
                    else:
                        self.logger.warning("Ignoring data URL without a base64 payload")
                else:
                    downloaded = await self.download_image(url)
                    if downloaded:
                        base64_images.append(downloaded)

        return " ".join(text_parts).strip()

    async def _get_client(self, request):
        """Return a pooled client, creating and configuring one if allowed.

        Raises:
            RuntimeError: if the pool lock cannot be taken within 30 seconds,
                or if the pool is empty and ``POOL_SIZE`` clients already exist.
        """
        # Acquire outside the try/finally: when the wait times out the lock
        # was never taken, and releasing it would raise and hide the timeout.
        try:
            await asyncio.wait_for(self._pool_lock.acquire(), timeout=30.0)
        except asyncio.TimeoutError as exc:
            # Callers handle RuntimeError for pool problems; keep that contract.
            raise RuntimeError("Timed out waiting for the client pool") from exc

        try:
            if self._client_pool:
                client = self._client_pool.pop()
                self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}")
                return client

            if len(self._active_clients) >= self.POOL_SIZE:
                raise RuntimeError("Connection pool exhausted")

            self.logger.debug("Creating new LLM client")
            client = LLMClient(
                host=self.config["host"],
                port=self.config["port"]
            )
            self._active_clients[id(client)] = client

            try:
                # The first system message wins; otherwise use the configured
                # (or built-in) default prompt.
                system_content = next(
                    (m.content for m in request.messages if m.role == "system"),
                    self.config.get("system_prompt", "You are a helpful assistant")
                )
                parsed_prompt = await self._parse_content(system_content, [])

                setup_params = {
                    "model": self.config["model_name"],
                    "response_format": self.config["response_format"],
                    "input": self.config["input"],
                    "enoutput": True,
                    "max_token_len": request.max_tokens,
                    "temperature": request.temperature,
                    "top_p": request.top_p,
                    "prompt": parsed_prompt
                }
                loop = asyncio.get_running_loop()
                # setup() is run on the default executor, off the event loop.
                await loop.run_in_executor(
                    None,
                    lambda: client.setup(self.config["object"], setup_params)
                )
            except BaseException:
                # A client that never finished setup is not returned to the
                # pool, so it must not keep occupying a slot.
                self._active_clients.pop(id(client), None)
                raise
            return client
        finally:
            self._pool_lock.release()

    async def _release_client(self, client):
        """Put ``client`` back into the pool for reuse."""
        async with self._pool_lock:
            self._client_pool.append(client)
            self.logger.debug(f"Returned client to pool | ID:{id(client)}")

    async def inference_stream(self, query: str, base64_images: list, request: ChatCompletionRequest):
        """Stream inference output for ``query`` as an async generator.

        Images are sent to the client first, then the client's synchronous
        generator is advanced on the worker executor one chunk at a time.
        A ``None`` chunk or generator exhaustion ends the stream. Errors other
        than cancellation are reported in-band as an ``[ERROR: ...]`` chunk.
        The client is always returned to the pool afterwards.
        """
        client = await self._get_client(request)
        task = asyncio.current_task()
        self._active_tasks.add(task)
        try:
            self.logger.debug(f"Starting inference | ClientID:{id(client)} Query length:{len(query)}")

            loop = asyncio.get_running_loop()
            for i, image_data in enumerate(base64_images):
                message = client.send_jpeg(image_data, object_type="vlm.jpeg.base64")
                self.logger.debug(f"发送第 {i+1} 张JPEG数据: {message[:20]}...")

            sync_gen = client.inference_stream(
                query,
                object_type="llm.utf-8"
            )

            def get_next():
                # Exhaustion is reported as None so no StopIteration has to
                # cross the executor boundary.
                return next(sync_gen, None)

            # No task.cancelled() polling here: it is never true while the
            # task is still running; cancellation arrives as CancelledError.
            while True:
                chunk = await loop.run_in_executor(
                    self._inference_executor,
                    get_next
                )
                if chunk is None:
                    break
                yield chunk
        except asyncio.CancelledError:
            self.logger.warning("Inference task cancelled, stopping...")
            client.stop_inference()
            raise
        except Exception as e:
            self.logger.error(f"Inference error: {str(e)}")
            yield f"[ERROR: {str(e)}]"
        finally:
            self._active_tasks.discard(task)
            await self._release_client(client)
            self.logger.debug(f"Inference stopped | ClientID:{id(client)}")

    @staticmethod
    def _content_length(content) -> int:
        """Return the number of text characters in a message's content."""
        if isinstance(content, str):
            return len(content)
        # Image parts carry no text and therefore count as zero.
        return sum(len(getattr(item, "text", None) or "") for item in content)

    def _truncate_history(self, messages: List[Message]) -> List[Message]:
        """Drop the oldest non-system messages that exceed the text budget.

        Every system message is kept. Non-system messages are kept newest
        first until adding one would exceed ``MAX_CONTEXT_LENGTH`` characters;
        that message and all older non-system messages are dropped.

        Returns:
            The kept messages in their original order.
        """
        total_length = 0
        budget_exhausted = False
        kept_newest_first = []

        for msg in reversed(messages):
            if msg.role == "system":
                kept_newest_first.append(msg)
                continue
            if budget_exhausted:
                # Keep scanning so older system messages are still collected.
                continue

            msg_length = self._content_length(msg.content)
            if total_length + msg_length > self.MAX_CONTEXT_LENGTH:
                budget_exhausted = True
                continue
            total_length += msg_length
            kept_newest_first.append(msg)

        kept_newest_first.reverse()
        return kept_newest_first

    async def download_image(self, url):
        """Fetch ``url`` and return its body base64-encoded, or ``None`` on failure.

        Failures (non-200 status or any exception) are logged, not raised.
        """
        # NOTE(review): the URL comes from the request body and is fetched
        # without a size cap or host restriction — confirm this is acceptable.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        self.logger.error(f"图片下载失败,状态码:{response.status}")
                        return None
                    image_data = await response.read()
                    return base64.b64encode(image_data).decode('utf-8')
        except Exception as e:
            self.logger.error(f"图片下载异常:{str(e)}")
            return None

    async def generate(self, request: ChatCompletionRequest):
        """Build the prompt from ``request`` and run inference.

        Returns:
            An async generator of ``chat.completion.chunk`` dicts when
            ``request.stream`` is set, otherwise one ``chat.completion`` dict.

        Raises:
            HTTPException: status 400 when a RuntimeError (e.g. an exhausted
                pool) surfaces while the response is being built here.
        """
        try:
            truncated_messages = self._truncate_history(request.messages)

            query_lines = []
            base64_images = []
            system_prompt = ""

            for m in truncated_messages:
                parsed = await self._parse_content(m.content, base64_images)
                if m.role == "system":
                    system_prompt += f"{parsed}\n"
                elif parsed:
                    query_lines.append(f"{m.role}: {parsed}")

            # System text first, then the dialogue; empty sections are dropped.
            query = "\n\n".join(
                filter(None, [system_prompt.strip(), "\n".join(query_lines)])
            )

            self.logger.debug(
                f"Processed query | System prompt: {len(system_prompt)} chars | "
                f"Images: {len(base64_images)} | Dialogue lines: {len(query_lines)}"
            )

            if request.stream:
                async def chunk_generator():
                    try:
                        async for chunk in self.inference_stream(query, base64_images, request):
                            yield {
                                "id": f"chatcmpl-{uuid.uuid4()}",
                                "object": "chat.completion.chunk",
                                "created": int(time.time()),
                                "model": request.model,
                                "choices": [{
                                    "index": 0,
                                    "delta": {"content": chunk},
                                    "finish_reason": None
                                }]
                            }
                        # Normal completion marker.
                        yield {
                            "choices": [{
                                "delta": {},
                                "finish_reason": "stop"
                            }]
                        }
                    except Exception as e:
                        self.logger.error(f"Stream generation error: {str(e)}")
                        yield {
                            "error": {
                                "message": f"Stream generation failed: {str(e)}",
                                "type": "api_error"
                            }
                        }
                        yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}
                        raise
                return chunk_generator()

            full_response = "".join(
                [chunk async for chunk in self.inference_stream(query, base64_images, request)]
            )
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": full_response
                    }
                }]
            }
        except RuntimeError as e:
            self.logger.error(f"Connection error: {str(e)}")
            raise HTTPException(
                status_code=400,
                detail=f"Model service connection failed: {str(e)}"
            ) from e
|
||||
|
||||
class ModelDispatcher:
|
||||
def __init__(self):
|
||||
self.backends = {}
|
||||
@@ -421,6 +58,8 @@ class ModelDispatcher:
|
||||
self.backends[model_name] = LlmClientBackend(model_config)
|
||||
elif model_config["type"] == "llama.cpp":
|
||||
self.backends[model_name] = TestBackend(model_config)
|
||||
elif model_config["type"] == "vision_model":
|
||||
self.backends[model_name] = VisionModelBackend(model_config)
|
||||
|
||||
def get_backend(self, model_name):
|
||||
return self.backends.get(model_name)
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Union
|
||||
from .chat_schemas import ChatCompletionRequest # Note: You'll need to move the request models to a schemas.py file
|
||||
|
||||
class BaseModelBackend:
    """Common base for model backends: stores the per-model configuration."""

    def __init__(self, model_config):
        """Keep ``model_config`` on the instance as ``self.config``."""
        self.config = model_config

    async def generate(self, request: ChatCompletionRequest):
        """Produce a completion for ``request``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,27 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Union
|
||||
|
||||
class ContentItem(BaseModel):
    """One part of a multi-part message body."""

    # Kind of part; the original note lists "text" and "image_url".
    type: str
    # Text payload; presumably set when ``type`` is "text".
    text: Optional[str] = None
    # Image reference; presumably set when ``type`` is "image_url".
    image_url: Optional[dict] = None
|
||||
|
||||
class Message(BaseModel):
    """A single chat message."""

    # Speaker of the message.
    role: str
    # Either plain text or a list of typed content parts.
    content: Union[str, List[ContentItem]]
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
    """Request body for a chat completion."""

    # Name of the model the caller is addressing.
    model: str
    # Conversation so far.
    messages: List[Message]
    # Sampling parameters; each falls back to the default shown when omitted.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 1000
    # Streaming flag, off by default.
    stream: Optional[bool] = False
|
||||
|
||||
class CompletionRequest(BaseModel):
    """Request body for a plain (prompt-based) completion."""

    # Name of the model the caller is addressing.
    model: str
    # Raw prompt text.
    prompt: str
    # Sampling parameters; each falls back to the default shown when omitted.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 1000
    # Streaming flag, off by default.
    stream: Optional[bool] = False
|
||||
@@ -0,0 +1,266 @@
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from .base_model_backend import BaseModelBackend
|
||||
from .chat_schemas import ChatCompletionRequest, Message, ContentItem
|
||||
from client.llm_client import LLMClient
|
||||
import aiohttp
|
||||
import base64
|
||||
import logging
|
||||
from fastapi import HTTPException
|
||||
from typing import Union, List
|
||||
|
||||
class LlmClientBackend(BaseModelBackend):
    """Backend that serves chat completions through pooled ``LLMClient`` objects.

    Clients are created on demand up to ``POOL_SIZE`` and recycled through
    ``_client_pool``. A client taken from the pool is returned as-is:
    ``setup()`` only runs when a client is first created, so a reused client
    keeps the parameters of the request that created it.
    """

    # Budget used by _truncate_history, in characters of message text.
    MAX_CONTEXT_LENGTH = 500
    # Caps both the number of clients created and the inference worker threads.
    POOL_SIZE = 2

    def __init__(self, model_config):
        """Create the empty pool, its lock, the logger and the worker executor."""
        super().__init__(model_config)
        self._client_pool = []
        self._active_clients = {}
        self._pool_lock = asyncio.Lock()
        self.logger = logging.getLogger("api.client")
        self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
        self._active_tasks = weakref.WeakSet()

    async def _parse_content(self, content: Union[str, List[ContentItem]], base64_images: list) -> str:
        """Flatten message content to text and collect its images.

        Args:
            content: Plain text or a list of content parts.
            base64_images: Output list; base64 payloads of every image part
                are appended to it in order.

        Returns:
            The text parts joined by single spaces (stripped).
        """
        if not isinstance(content, list):
            return str(content).strip()

        text_parts = []
        for item in content:
            if item.type == "text" and item.text:
                text_parts.append(item.text.strip())
            elif item.type == "image_url" and item.image_url:
                url = item.image_url.get("url", "")
                if url.startswith("data:image"):
                    # partition never raises; a data URL without a comma
                    # simply has no payload and is skipped.
                    _, _, payload = url.partition(",")
                    if payload:
                        base64_images.append(payload)
                    else:
                        self.logger.warning("Ignoring data URL without a base64 payload")
                else:
                    downloaded = await self.download_image(url)
                    if downloaded:
                        base64_images.append(downloaded)

        return " ".join(text_parts).strip()

    async def _get_client(self, request):
        """Return a pooled client, creating and configuring one if allowed.

        Raises:
            RuntimeError: if the pool lock cannot be taken within 30 seconds,
                or if the pool is empty and ``POOL_SIZE`` clients already exist.
        """
        # Acquire outside the try/finally: when the wait times out the lock
        # was never taken, and releasing it would raise and hide the timeout.
        try:
            await asyncio.wait_for(self._pool_lock.acquire(), timeout=30.0)
        except asyncio.TimeoutError as exc:
            # Callers handle RuntimeError for pool problems; keep that contract.
            raise RuntimeError("Timed out waiting for the client pool") from exc

        try:
            if self._client_pool:
                client = self._client_pool.pop()
                self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}")
                return client

            if len(self._active_clients) >= self.POOL_SIZE:
                raise RuntimeError("Connection pool exhausted")

            self.logger.debug("Creating new LLM client")
            client = LLMClient(
                host=self.config["host"],
                port=self.config["port"]
            )
            self._active_clients[id(client)] = client

            try:
                # The first system message wins; otherwise use the configured
                # (or built-in) default prompt.
                system_content = next(
                    (m.content for m in request.messages if m.role == "system"),
                    self.config.get("system_prompt", "You are a helpful assistant")
                )
                parsed_prompt = await self._parse_content(system_content, [])

                setup_params = {
                    "model": self.config["model_name"],
                    "response_format": self.config["response_format"],
                    "input": self.config["input"],
                    "enoutput": True,
                    "max_token_len": request.max_tokens,
                    "temperature": request.temperature,
                    "top_p": request.top_p,
                    "prompt": parsed_prompt
                }
                loop = asyncio.get_running_loop()
                # setup() is run on the default executor, off the event loop.
                await loop.run_in_executor(
                    None,
                    lambda: client.setup(self.config["object"], setup_params)
                )
            except BaseException:
                # A client that never finished setup is not returned to the
                # pool, so it must not keep occupying a slot.
                self._active_clients.pop(id(client), None)
                raise
            return client
        finally:
            self._pool_lock.release()

    async def _release_client(self, client):
        """Put ``client`` back into the pool for reuse."""
        async with self._pool_lock:
            self._client_pool.append(client)
            self.logger.debug(f"Returned client to pool | ID:{id(client)}")

    async def inference_stream(self, query: str, base64_images: list, request: ChatCompletionRequest):
        """Stream inference output for ``query`` as an async generator.

        Images are sent to the client first, then the client's synchronous
        generator is advanced on the worker executor one chunk at a time.
        A ``None`` chunk or generator exhaustion ends the stream. Errors other
        than cancellation are reported in-band as an ``[ERROR: ...]`` chunk.
        The client is always returned to the pool afterwards.
        """
        client = await self._get_client(request)
        task = asyncio.current_task()
        self._active_tasks.add(task)
        try:
            self.logger.debug(f"Starting inference | ClientID:{id(client)} Query length:{len(query)}")

            loop = asyncio.get_running_loop()
            for i, image_data in enumerate(base64_images):
                message = client.send_jpeg(image_data, object_type="vlm.jpeg.base64")
                self.logger.debug(f"发送第 {i+1} 张JPEG数据: {message[:20]}...")

            sync_gen = client.inference_stream(
                query,
                object_type="llm.utf-8"
            )

            def get_next():
                # Exhaustion is reported as None so no StopIteration has to
                # cross the executor boundary.
                return next(sync_gen, None)

            # No task.cancelled() polling here: it is never true while the
            # task is still running; cancellation arrives as CancelledError.
            while True:
                chunk = await loop.run_in_executor(
                    self._inference_executor,
                    get_next
                )
                if chunk is None:
                    break
                yield chunk
        except asyncio.CancelledError:
            self.logger.warning("Inference task cancelled, stopping...")
            client.stop_inference()
            raise
        except Exception as e:
            self.logger.error(f"Inference error: {str(e)}")
            yield f"[ERROR: {str(e)}]"
        finally:
            self._active_tasks.discard(task)
            await self._release_client(client)
            self.logger.debug(f"Inference stopped | ClientID:{id(client)}")

    @staticmethod
    def _content_length(content) -> int:
        """Return the number of text characters in a message's content."""
        if isinstance(content, str):
            return len(content)
        # Image parts carry no text and therefore count as zero.
        return sum(len(getattr(item, "text", None) or "") for item in content)

    def _truncate_history(self, messages: List[Message]) -> List[Message]:
        """Drop the oldest non-system messages that exceed the text budget.

        Every system message is kept. Non-system messages are kept newest
        first until adding one would exceed ``MAX_CONTEXT_LENGTH`` characters;
        that message and all older non-system messages are dropped.

        Returns:
            The kept messages in their original order.
        """
        total_length = 0
        budget_exhausted = False
        kept_newest_first = []

        for msg in reversed(messages):
            if msg.role == "system":
                kept_newest_first.append(msg)
                continue
            if budget_exhausted:
                # Keep scanning so older system messages are still collected.
                continue

            msg_length = self._content_length(msg.content)
            if total_length + msg_length > self.MAX_CONTEXT_LENGTH:
                budget_exhausted = True
                continue
            total_length += msg_length
            kept_newest_first.append(msg)

        kept_newest_first.reverse()
        return kept_newest_first

    async def download_image(self, url):
        """Fetch ``url`` and return its body base64-encoded, or ``None`` on failure.

        Failures (non-200 status or any exception) are logged, not raised.
        """
        # NOTE(review): the URL comes from the request body and is fetched
        # without a size cap or host restriction — confirm this is acceptable.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        self.logger.error(f"图片下载失败,状态码:{response.status}")
                        return None
                    image_data = await response.read()
                    return base64.b64encode(image_data).decode('utf-8')
        except Exception as e:
            self.logger.error(f"图片下载异常:{str(e)}")
            return None

    async def generate(self, request: ChatCompletionRequest):
        """Build the prompt from ``request`` and run inference.

        Returns:
            An async generator of ``chat.completion.chunk`` dicts when
            ``request.stream`` is set, otherwise one ``chat.completion`` dict.

        Raises:
            HTTPException: status 400 when a RuntimeError (e.g. an exhausted
                pool) surfaces while the response is being built here.
        """
        try:
            truncated_messages = self._truncate_history(request.messages)

            query_lines = []
            base64_images = []
            system_prompt = ""

            for m in truncated_messages:
                parsed = await self._parse_content(m.content, base64_images)
                if m.role == "system":
                    system_prompt += f"{parsed}\n"
                elif parsed:
                    query_lines.append(f"{m.role}: {parsed}")

            # System text first, then the dialogue; empty sections are dropped.
            query = "\n\n".join(
                filter(None, [system_prompt.strip(), "\n".join(query_lines)])
            )

            self.logger.debug(
                f"Processed query | System prompt: {len(system_prompt)} chars | "
                f"Images: {len(base64_images)} | Dialogue lines: {len(query_lines)}"
            )

            if request.stream:
                async def chunk_generator():
                    try:
                        async for chunk in self.inference_stream(query, base64_images, request):
                            yield {
                                "id": f"chatcmpl-{uuid.uuid4()}",
                                "object": "chat.completion.chunk",
                                "created": int(time.time()),
                                "model": request.model,
                                "choices": [{
                                    "index": 0,
                                    "delta": {"content": chunk},
                                    "finish_reason": None
                                }]
                            }
                        # Normal completion marker.
                        yield {
                            "choices": [{
                                "delta": {},
                                "finish_reason": "stop"
                            }]
                        }
                    except Exception as e:
                        self.logger.error(f"Stream generation error: {str(e)}")
                        yield {
                            "error": {
                                "message": f"Stream generation failed: {str(e)}",
                                "type": "api_error"
                            }
                        }
                        yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}
                        raise
                return chunk_generator()

            full_response = "".join(
                [chunk async for chunk in self.inference_stream(query, base64_images, request)]
            )
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": full_response
                    }
                }]
            }
        except RuntimeError as e:
            self.logger.error(f"Connection error: {str(e)}")
            raise HTTPException(
                status_code=400,
                detail=f"Model service connection failed: {str(e)}"
            ) from e
|
||||
@@ -0,0 +1,44 @@
|
||||
import uuid
|
||||
import time
|
||||
from openai import AsyncOpenAI
|
||||
from .base_model_backend import BaseModelBackend
|
||||
from .chat_schemas import ChatCompletionRequest
|
||||
from fastapi import HTTPException
|
||||
|
||||
class OpenAIProxyBackend(BaseModelBackend):
    """Backend that forwards chat requests to an OpenAI-compatible endpoint.

    Reads ``api_key``, ``base_url`` and ``model`` from the model config.
    """

    async def generate(self, request: ChatCompletionRequest):
        """Forward ``request`` upstream and return the upstream result.

        Returns:
            The SDK response object for non-streaming requests, or an async
            generator re-yielding the upstream chunks when ``request.stream``
            is set. An ``APIError`` raised mid-stream is reported in-band as
            an ``{"error": ...}`` dict.

        Raises:
            HTTPException: status 500 for any failure before streaming starts.
        """
        # APIError is only imported here, so the local import stays.
        from openai import AsyncOpenAI, APIError

        try:
            client = AsyncOpenAI(
                api_key=self.config["api_key"],
                base_url=self.config["base_url"]
            )

            # exclude_none drops the unused half of each content part
            # (e.g. ``image_url: null`` on a text part) instead of sending
            # explicit nulls upstream.
            upstream_messages = [
                m.model_dump(exclude_none=True) for m in request.messages
            ]

            response = await client.chat.completions.create(
                model=self.config["model"],
                messages=upstream_messages,
                temperature=request.temperature,
                # top_p is part of the request schema and was previously dropped.
                top_p=request.top_p,
                max_tokens=request.max_tokens,
                stream=request.stream
            )

            if not request.stream:
                return response

            async def async_wrapper():
                try:
                    async for chunk in response:
                        yield chunk
                except APIError as e:
                    yield {
                        "error": {
                            "message": f"OpenAI API Error: {str(e)}",
                            "type": "api_error"
                        }
                    }

            return async_wrapper()
        except Exception as e:
            # Chain the cause so the original traceback is kept.
            raise HTTPException(
                status_code=500,
                detail=f"OpenAI proxy error: {str(e)}"
            ) from e
|
||||
@@ -0,0 +1,56 @@
|
||||
import uuid
|
||||
import time
|
||||
from .base_model_backend import BaseModelBackend
|
||||
from .chat_schemas import ChatCompletionRequest
|
||||
|
||||
class TestBackend(BaseModelBackend):
    """Stub backend that answers every request with a fixed emoji reply."""

    async def generate(self, request: ChatCompletionRequest):
        """Return a canned completion for ``request``.

        Non-streaming requests get one ``chat.completion`` dict with fixed
        usage numbers. Streaming requests get an async generator yielding one
        ``chat.completion.chunk`` dict per emoji; the first chunk carries the
        assistant role and the last one carries ``finish_reason="stop"``.
        """
        if not request.stream:
            reply = {
                "role": "assistant",
                "content": "🤣👉🏻🤡",
                "function_call": None,
                "tool_calls": None
            }
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "message": reply,
                    "finish_reason": "stop",
                    "index": 0
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30
                }
            }

        async def chunk_generator():
            pieces = ["🤣", "👉🏻", "🤡"]
            # The incoming messages are only dumped to stdout for inspection.
            dumped = [msg.model_dump() for msg in request.messages]
            print(f"messages:_____________{dumped}______________")
            final_position = len(pieces) - 1
            for position, piece in enumerate(pieces):
                delta = {
                    "content": piece,
                    "role": "assistant" if position == 0 else None,
                    "function_call": None,
                    "tool_calls": None
                }
                choice = {
                    "index": 0,
                    "delta": delta,
                    "logprobs": None,
                    "finish_reason": "stop" if position == final_position else None
                }
                yield {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [choice],
                    "service_tier": None,
                    "system_fingerprint": None,
                    "usage": None
                }

        return chunk_generator()
|
||||
@@ -0,0 +1,142 @@
|
||||
import uuid
|
||||
import time
|
||||
from openai import AsyncOpenAI
|
||||
from .base_model_backend import BaseModelBackend
|
||||
from .chat_schemas import ChatCompletionRequest, Message, ContentItem
|
||||
from fastapi import HTTPException
|
||||
from typing import List
|
||||
|
||||
class VisionModelBackend(BaseModelBackend):
    """Backend for OpenAI-compatible vision models.

    External image URLs are downloaded and inlined as base64 data URLs before
    the request is forwarded upstream. Reads ``api_key``, ``base_url`` and
    ``model`` from the model config; ``max_image_size`` and ``image_timeout``
    are optional overrides for the class defaults below.
    """

    MAX_IMAGE_SIZE = 4 * 1024 * 1024  # 4MB
    IMAGE_TIMEOUT = 15  # seconds

    async def download_image(self, url, max_size=None, timeout=None):
        """Fetch ``url`` and return its body base64-encoded, or ``None``.

        Args:
            url: http(s) URL of the image.
            max_size: Largest accepted body in bytes (default ``MAX_IMAGE_SIZE``).
            timeout: Socket timeout in seconds (default ``IMAGE_TIMEOUT``).

        Returns:
            The base64 text, or ``None`` for a non-http(s) URL, a non-200
            status, an oversized body, or a network/URL error.
        """
        # Function-scope imports: this module does not import these names.
        import asyncio
        import base64
        import http.client
        import urllib.request
        from urllib.parse import urlsplit

        size_cap = self.MAX_IMAGE_SIZE if max_size is None else max_size
        wait_seconds = self.IMAGE_TIMEOUT if timeout is None else timeout

        def fetch():
            # urlopen also understands file:// and similar schemes; the URL is
            # caller-supplied, so only plain web URLs are allowed.
            # NOTE(review): internal hosts are still reachable — confirm
            # whether an allow-list is needed.
            if urlsplit(url).scheme not in ("http", "https"):
                return None
            with urllib.request.urlopen(url, timeout=wait_seconds) as response:
                if response.status != 200:
                    return None
                # Read one byte past the cap to detect oversized bodies
                # without buffering them entirely.
                payload = response.read(size_cap + 1)
            if len(payload) > size_cap:
                return None
            return base64.b64encode(payload).decode("utf-8")

        try:
            # urlopen blocks, so it runs on the default executor.
            return await asyncio.get_running_loop().run_in_executor(None, fetch)
        except (OSError, ValueError, http.client.HTTPException):
            return None

    async def _process_image_content(self, content_item: ContentItem) -> dict:
        """Convert an image content part to an upstream ``image_url`` part.

        Returns:
            ``None`` when the part has no ``image_url``; otherwise a dict whose
            URL is a data URL (passed through, or built from the download).

        Raises:
            HTTPException: status 400 when an external image cannot be loaded.
        """
        if not content_item.image_url:
            return None

        url = content_item.image_url.get("url", "")
        if url.startswith("data:image"):
            return {
                "type": "image_url",
                "image_url": {"url": url}
            }

        # Download the external image and convert it to base64; limits come
        # from the model config when present.
        base64_str = await self.download_image(
            url,
            max_size=self.config.get("max_image_size", self.MAX_IMAGE_SIZE),
            timeout=self.config.get("image_timeout", self.IMAGE_TIMEOUT)
        )
        if not base64_str:
            raise HTTPException(
                status_code=400,
                detail=f"无法加载图片: {url}"
            )

        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_str}"
            }
        }

    async def _build_messages(self, messages: List[Message]):
        """Convert request messages to the upstream list-of-parts format.

        String content becomes a single text part; list content is converted
        part by part, with images inlined via ``_process_image_content``.
        """
        processed_messages = []

        for msg in messages:
            content = msg.content
            new_content = []

            if isinstance(content, list):
                for item in content:
                    if item.type == "text":
                        new_content.append({
                            "type": "text",
                            "text": item.text
                        })
                    elif item.type == "image_url":
                        image_content = await self._process_image_content(item)
                        if image_content:
                            new_content.append(image_content)
            else:
                new_content.append({
                    "type": "text",
                    "text": str(content)
                })

            processed_messages.append({
                "role": msg.role,
                "content": new_content
            })

        return processed_messages

    async def generate(self, request: ChatCompletionRequest):
        """Forward ``request`` to the vision model.

        Returns:
            An async generator of ``chat.completion.chunk`` dicts when
            ``request.stream`` is set, otherwise one ``chat.completion`` dict
            including usage numbers (zero when upstream reports none).

        Raises:
            HTTPException: status 400 for an unloadable image, status 500 for
                any other failure before streaming starts.
        """
        try:
            client = AsyncOpenAI(
                api_key=self.config["api_key"],
                base_url=self.config["base_url"],
                timeout=30.0
            )

            messages = await self._build_messages(request.messages)

            response = await client.chat.completions.create(
                model=self.config["model"],
                messages=messages,
                temperature=request.temperature,
                # top_p is part of the request schema and was previously dropped.
                top_p=request.top_p,
                max_tokens=request.max_tokens,
                stream=request.stream
            )

            if request.stream:
                async def stream_wrapper():
                    async for chunk in response:
                        # Uniform error handling: pass error dicts through.
                        if isinstance(chunk, dict) and "error" in chunk:
                            yield chunk
                            continue

                        # Chunks without choices have nothing to convert and
                        # would otherwise fail on choices[0].
                        if not chunk.choices:
                            continue
                        first_choice = chunk.choices[0]

                        # Convert to the compatible chunk format.
                        yield {
                            "id": f"chatcmpl-{uuid.uuid4()}",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request.model,
                            "choices": [{
                                "index": 0,
                                "delta": {
                                    "content": first_choice.delta.content or "",
                                    "role": "assistant"
                                },
                                "finish_reason": first_choice.finish_reason
                            }]
                        }
                    yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}
                return stream_wrapper()

            # Non-streaming responses also carry usage information.
            usage = response.usage
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": response.choices[0].message.content
                    }
                }],
                "usage": {
                    "prompt_tokens": usage.prompt_tokens if usage else 0,
                    "completion_tokens": usage.completion_tokens if usage else 0,
                    "total_tokens": usage.total_tokens if usage else 0
                }
            }
        except HTTPException:
            # Keep deliberate HTTP errors (e.g. the 400 for an unloadable
            # image) instead of rewrapping them as 500.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Vision model error: {str(e)}"
            ) from e
|
||||
@@ -6,26 +6,19 @@ server:
|
||||
models:
|
||||
llama2-7b:
|
||||
type: llama.cpp
|
||||
path: ./models/llama-2-7b-chat.Q4_K_M.gguf
|
||||
params:
|
||||
n_ctx: 4096
|
||||
n_gpu_layers: 35
|
||||
|
||||
# OpenAI API
|
||||
gpt-3.5-turbo-proxy:
|
||||
type: openai_proxy
|
||||
api_key:
|
||||
api_key: sk-
|
||||
base_url: https://api.openai.com/v1
|
||||
model: gpt-3.5-turbo
|
||||
|
||||
# DeepSeek API
|
||||
deepseek-r1:
|
||||
type: openai_proxy
|
||||
api_key:
|
||||
api_key: sk-
|
||||
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
model: deepseek-r1
|
||||
|
||||
# ModuleLLM API
|
||||
qwen2.5-0.5b:
|
||||
type: tcp_client
|
||||
host: "192.168.20.183"
|
||||
@@ -33,6 +26,8 @@ models:
|
||||
model_name: "qwen2.5-0.5B-prefill-20e"
|
||||
object: "llm.setup"
|
||||
pool_size: 2
|
||||
response_format: "llm.utf-8.stream"
|
||||
input: "llm.utf-8"
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
@@ -70,4 +65,12 @@ models:
|
||||
response_format: "vlm.utf-8.stream"
|
||||
input: "vlm.utf-8"
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
You are a helpful assistant.
|
||||
|
||||
qwen-vl-plus:
|
||||
type: vision_model
|
||||
api_key: sk-
|
||||
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
model: qwen-vl-plus
|
||||
max_image_size: 4194304
|
||||
image_timeout: 20
|
||||
Reference in New Issue
Block a user