You've already forked ModuleLLM-OpenAI-Plugin
mirror of
https://github.com/m5stack/ModuleLLM-OpenAI-Plugin.git
synced 2026-05-20 11:37:26 -07:00
555 lines
21 KiB
Python
555 lines
21 KiB
Python
import os
|
|
import uuid
|
|
import yaml
|
|
from fastapi import FastAPI, Request, HTTPException
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
from typing import Optional, List, Union
|
|
import logging
|
|
from slowapi import Limiter
|
|
from slowapi.util import get_remote_address
|
|
import time
|
|
import json
|
|
import asyncio
|
|
from llm_client import LLMClient
|
|
import aiohttp
|
|
import base64
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
import weakref
|
|
|
|
logging.basicConfig(
|
|
level=logging.DEBUG,
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
handlers=[
|
|
logging.StreamHandler(),
|
|
]
|
|
)
|
|
logger = logging.getLogger("api")
|
|
|
|
app = FastAPI(title="OpenAI Compatible API Server")
|
|
limiter = Limiter(key_func=get_remote_address)
|
|
|
|
class Config:
|
|
def __init__(self):
|
|
with open("config.yaml") as f:
|
|
self.data = yaml.safe_load(f)
|
|
|
|
config = Config()
|
|
|
|
class ContentItem(BaseModel):
|
|
type: str # text/image_url
|
|
text: Optional[str] = None
|
|
image_url: Optional[dict] = None
|
|
|
|
class Message(BaseModel):
|
|
role: str
|
|
content: Union[str, List[ContentItem]]
|
|
|
|
class ChatCompletionRequest(BaseModel):
|
|
model: str
|
|
messages: List[Message]
|
|
temperature: Optional[float] = 0.7
|
|
top_p: Optional[float] = 0.9
|
|
max_tokens: Optional[int] = 1000
|
|
stream: Optional[bool] = False
|
|
|
|
class CompletionRequest(BaseModel):
|
|
model: str
|
|
prompt: str
|
|
temperature: Optional[float] = 0.7
|
|
top_p: Optional[float] = 0.9
|
|
max_tokens: Optional[int] = 1000
|
|
stream: Optional[bool] = False
|
|
|
|
@app.middleware("http")
|
|
async def auth_middleware(request: Request, call_next):
|
|
if request.url.path.startswith("/v1"):
|
|
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
|
|
if api_key != os.getenv("API_KEY"):
|
|
return JSONResponse(
|
|
status_code=401,
|
|
content={"error": "Invalid authentication credentials"}
|
|
)
|
|
return await call_next(request)
|
|
|
|
class BaseModelBackend:
|
|
def __init__(self, model_config):
|
|
self.config = model_config
|
|
|
|
async def generate(self, request: ChatCompletionRequest):
|
|
raise NotImplementedError
|
|
|
|
class TestBackend(BaseModelBackend):
|
|
async def generate(self, request: ChatCompletionRequest):
|
|
if request.stream:
|
|
async def chunk_generator():
|
|
content_parts = ["🤣", "👉🏻", "🤡"]
|
|
messages=[m.model_dump() for m in request.messages]
|
|
print(f"messages:_____________{messages}______________")
|
|
for i, part in enumerate(content_parts):
|
|
yield {
|
|
"id": f"chatcmpl-{uuid.uuid4()}",
|
|
"object": "chat.completion.chunk",
|
|
"created": int(time.time()),
|
|
"model": request.model,
|
|
"choices": [{
|
|
"index": 0,
|
|
"delta": {
|
|
"content": part,
|
|
"role": "assistant" if i == 0 else None,
|
|
"function_call": None,
|
|
"tool_calls": None
|
|
},
|
|
"logprobs": None,
|
|
"finish_reason": "stop" if i == len(content_parts)-1 else None
|
|
}],
|
|
"service_tier": None,
|
|
"system_fingerprint": None,
|
|
"usage": None
|
|
}
|
|
return chunk_generator()
|
|
else:
|
|
return {
|
|
"id": f"chatcmpl-{uuid.uuid4()}",
|
|
"object": "chat.completion",
|
|
"created": int(time.time()),
|
|
"model": request.model,
|
|
"choices": [{
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "🤣👉🏻🤡",
|
|
"function_call": None,
|
|
"tool_calls": None
|
|
},
|
|
"finish_reason": "stop",
|
|
"index": 0
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": 10,
|
|
"completion_tokens": 20,
|
|
"total_tokens": 30
|
|
}
|
|
}
|
|
|
|
class OpenAIProxyBackend(BaseModelBackend):
|
|
async def generate(self, request: ChatCompletionRequest):
|
|
from openai import AsyncOpenAI
|
|
|
|
client = AsyncOpenAI(
|
|
api_key=self.config["api_key"],
|
|
base_url=self.config["base_url"]
|
|
)
|
|
|
|
response = await client.chat.completions.create(
|
|
model=self.config["model"],
|
|
messages=[m.model_dump() for m in request.messages],
|
|
temperature=request.temperature,
|
|
max_tokens=request.max_tokens,
|
|
stream=request.stream
|
|
)
|
|
|
|
if request.stream:
|
|
async def async_wrapper():
|
|
async for chunk in response:
|
|
yield chunk
|
|
return async_wrapper()
|
|
return response
|
|
|
|
class LlmClientBackend(BaseModelBackend):
|
|
MAX_CONTEXT_LENGTH = 500
|
|
POOL_SIZE = 2
|
|
|
|
def __init__(self, model_config):
|
|
super().__init__(model_config)
|
|
self._client_pool = []
|
|
self._active_clients = {}
|
|
self._pool_lock = asyncio.Lock()
|
|
self.logger = logging.getLogger("api.client")
|
|
self._inference_executor = ThreadPoolExecutor(max_workers=self.POOL_SIZE)
|
|
self._active_tasks = weakref.WeakSet()
|
|
|
|
async def _parse_content(self, content: Union[str, List[ContentItem]], base64_images: list) -> str:
|
|
text_parts = []
|
|
|
|
if isinstance(content, list):
|
|
for item in content:
|
|
if item.type == "text" and item.text:
|
|
text_parts.append(item.text.strip())
|
|
elif item.type == "image_url" and item.image_url:
|
|
url = item.image_url.get("url", "")
|
|
if url.startswith("data:image"):
|
|
base64_data = url.split(",", 1)[1]
|
|
base64_images.append(base64_data)
|
|
else:
|
|
base64_str = await self.download_image(url)
|
|
if base64_str:
|
|
base64_images.append(base64_str)
|
|
else:
|
|
text_parts.append(str(content).strip())
|
|
|
|
return " ".join(text_parts).strip()
|
|
|
|
async def _get_client(self, request):
|
|
try:
|
|
await asyncio.wait_for(self._pool_lock.acquire(), timeout=30.0)
|
|
|
|
if self._client_pool:
|
|
client = self._client_pool.pop()
|
|
self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}")
|
|
return client
|
|
|
|
if len(self._active_clients) >= self.POOL_SIZE:
|
|
raise RuntimeError("Connection pool exhausted")
|
|
|
|
self.logger.debug("Creating new LLM client")
|
|
client = LLMClient(
|
|
host=self.config["host"],
|
|
port=self.config["port"]
|
|
)
|
|
self._active_clients[id(client)] = client
|
|
|
|
system_content = next(
|
|
(m.content for m in request.messages if m.role == "system"),
|
|
self.config.get("system_prompt", "You are a helpful assistant")
|
|
)
|
|
parsed_prompt = await self._parse_content(system_content, [])
|
|
|
|
loop = asyncio.get_event_loop()
|
|
await loop.run_in_executor(
|
|
None,
|
|
lambda: client.setup(
|
|
self.config["object"],
|
|
{
|
|
"model": self.config["model_name"],
|
|
"response_format": self.config["response_format"],
|
|
"input": self.config["input"],
|
|
"enoutput": True,
|
|
"max_token_len": request.max_tokens,
|
|
"temperature": request.temperature,
|
|
"top_p": request.top_p,
|
|
"prompt": parsed_prompt
|
|
}
|
|
)
|
|
)
|
|
return client
|
|
finally:
|
|
self._pool_lock.release()
|
|
|
|
async def _release_client(self, client):
|
|
async with self._pool_lock:
|
|
self._client_pool.append(client)
|
|
self.logger.debug(f"Returned client to pool | ID:{id(client)}")
|
|
|
|
async def inference_stream(self, query: str, base64_images: list, request: ChatCompletionRequest):
|
|
client = await self._get_client(request)
|
|
task = asyncio.current_task()
|
|
self._active_tasks.add(task)
|
|
try:
|
|
self.logger.debug(f"Starting inference | ClientID:{id(client)} Query length:{len(query)}")
|
|
|
|
loop = asyncio.get_event_loop()
|
|
message = client.send_jpeg(base64_images[0], object_type="vlm.jpeg.base64")
|
|
print(f"jpeg date:{message}")
|
|
sync_gen = client.inference_stream(
|
|
query,
|
|
object_type="llm.utf-8"
|
|
)
|
|
|
|
while True:
|
|
if task.cancelled():
|
|
client.stop_inference()
|
|
break
|
|
|
|
def get_next():
|
|
try:
|
|
return next(sync_gen)
|
|
except StopIteration:
|
|
return None
|
|
|
|
chunk = await loop.run_in_executor(
|
|
self._inference_executor,
|
|
get_next
|
|
)
|
|
if chunk is None:
|
|
break
|
|
yield chunk
|
|
except asyncio.CancelledError:
|
|
self.logger.warning("Inference task cancelled, stopping...")
|
|
client.stop_inference()
|
|
raise
|
|
except Exception as e:
|
|
self.logger.error(f"Inference error: {str(e)}")
|
|
yield f"[ERROR: {str(e)}]"
|
|
finally:
|
|
self._active_tasks.discard(task)
|
|
await self._release_client(client)
|
|
self.logger.debug(f"Inference stopped | ClientID:{id(client)}")
|
|
|
|
def _truncate_history(self, messages: List[Message]) -> List[Message]:
|
|
"""Truncate history to fit model context window"""
|
|
total_length = 0
|
|
keep_messages = []
|
|
|
|
# Process in reverse to keep latest messages
|
|
for msg in reversed(messages):
|
|
if msg.role == "system": # Always keep system messages
|
|
keep_messages.insert(0, msg)
|
|
continue
|
|
|
|
msg_length = len(msg.content)
|
|
if total_length + msg_length > self.MAX_CONTEXT_LENGTH:
|
|
break
|
|
total_length += msg_length
|
|
keep_messages.insert(0, msg) # Maintain original order
|
|
|
|
return keep_messages
|
|
|
|
async def download_image(self, url):
|
|
try:
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(url) as response:
|
|
if response.status == 200:
|
|
image_data = await response.read()
|
|
return base64.b64encode(image_data).decode('utf-8')
|
|
self.logger.error(f"图片下载失败,状态码:{response.status}")
|
|
return None
|
|
except Exception as e:
|
|
self.logger.error(f"图片下载异常:{str(e)}")
|
|
return None
|
|
|
|
async def generate(self, request: ChatCompletionRequest):
|
|
try:
|
|
truncated_messages = self._truncate_history(request.messages)
|
|
|
|
query_lines = []
|
|
base64_images = []
|
|
system_prompt = ""
|
|
|
|
for m in truncated_messages:
|
|
if m.role == "system":
|
|
system_content = await self._parse_content(m.content, base64_images)
|
|
system_prompt += f"{system_content}\n"
|
|
continue
|
|
|
|
message_content = await self._parse_content(m.content, base64_images)
|
|
if message_content:
|
|
query_lines.append(f"{m.role}: {message_content}")
|
|
|
|
final_query = []
|
|
if system_prompt:
|
|
final_query.append(system_prompt.strip())
|
|
if base64_images:
|
|
pass
|
|
# final_query.append("\n".join([f"[IMAGE:{img[:20]}...]" for img in base64_images]))
|
|
final_query.append("\n".join(query_lines))
|
|
|
|
query = "\n\n".join(filter(None, final_query))
|
|
|
|
self.logger.debug(
|
|
f"Processed query | System prompt: {len(system_prompt)} chars | "
|
|
f"Images: {len(base64_images)} | Dialogue lines: {len(query_lines)}"
|
|
)
|
|
|
|
if request.stream:
|
|
async def chunk_generator():
|
|
try:
|
|
async for chunk in self.inference_stream(query, base64_images, request):
|
|
yield {
|
|
"id": f"chatcmpl-{uuid.uuid4()}",
|
|
"object": "chat.completion.chunk",
|
|
"created": int(time.time()),
|
|
"model": request.model,
|
|
"choices": [{
|
|
"index": 0,
|
|
"delta": {"content": chunk},
|
|
"finish_reason": None
|
|
}]
|
|
}
|
|
# Add normal completion marker
|
|
yield {
|
|
"choices": [{
|
|
"delta": {},
|
|
"finish_reason": "stop"
|
|
}]
|
|
}
|
|
except Exception as e:
|
|
self.logger.error(f"Stream generation error: {str(e)}")
|
|
yield {
|
|
"error": {
|
|
"message": f"Stream generation failed: {str(e)}",
|
|
"type": "api_error"
|
|
}
|
|
}
|
|
yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}
|
|
raise
|
|
return chunk_generator()
|
|
else:
|
|
full_response = ""
|
|
async for chunk in self.inference_stream(query, base64_images, request):
|
|
full_response += chunk
|
|
return {
|
|
"id": f"chatcmpl-{uuid.uuid4()}",
|
|
"object": "chat.completion",
|
|
"created": int(time.time()),
|
|
"model": request.model,
|
|
"choices": [{
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": full_response
|
|
}
|
|
}]
|
|
}
|
|
except RuntimeError as e:
|
|
self.logger.error(f"Connection error: {str(e)}")
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Model service connection failed: {str(e)}"
|
|
)
|
|
|
|
class ModelDispatcher:
|
|
def __init__(self):
|
|
self.backends = {}
|
|
self.load_models()
|
|
|
|
def load_models(self):
|
|
for model_name, model_config in config.data["models"].items():
|
|
if model_config["type"] == "openai_proxy":
|
|
self.backends[model_name] = OpenAIProxyBackend(model_config)
|
|
elif model_config["type"] == "tcp_client":
|
|
self.backends[model_name] = LlmClientBackend(model_config)
|
|
elif model_config["type"] == "llama.cpp":
|
|
self.backends[model_name] = TestBackend(model_config)
|
|
|
|
def get_backend(self, model_name):
|
|
return self.backends.get(model_name)
|
|
|
|
_dispatcher = ModelDispatcher()
|
|
|
|
@app.post("/v1/chat/completions")
|
|
async def chat_completions(request: Request, body: ChatCompletionRequest):
|
|
backend = _dispatcher.get_backend(body.model)
|
|
if not backend:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Unsupported model: {body.model}"
|
|
)
|
|
|
|
try:
|
|
print(f"Received request: {body.model_dump()}")
|
|
|
|
if body.stream:
|
|
chunk_generator = await backend.generate(body)
|
|
if not chunk_generator:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail="Failed to generate stream response"
|
|
)
|
|
|
|
async def format_stream():
|
|
try:
|
|
async for chunk in chunk_generator:
|
|
if isinstance(chunk, dict):
|
|
chunk_dict = chunk
|
|
else:
|
|
chunk_dict = chunk.model_dump()
|
|
|
|
json_chunk = json.dumps(chunk_dict, ensure_ascii=False)
|
|
print(f"Sending chunk: {json_chunk}")
|
|
yield f"data: {json_chunk}\n\n"
|
|
except asyncio.CancelledError:
|
|
logger.warning("客户端提前断开连接,正在终止推理...")
|
|
if backend and isinstance(backend, LlmClientBackend):
|
|
for task in backend._active_tasks:
|
|
task.cancel()
|
|
raise
|
|
finally:
|
|
logger.debug("流连接已关闭")
|
|
|
|
return StreamingResponse(
|
|
format_stream(),
|
|
media_type="text/event-stream"
|
|
)
|
|
else:
|
|
response = await backend.generate(body)
|
|
print(f"Sending response: {response}")
|
|
return JSONResponse(content=response)
|
|
|
|
except HTTPException as he:
|
|
raise he
|
|
except Exception as e:
|
|
logger.error(f"Processing error: {str(e)}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.post("/v1/completions")
|
|
async def create_completion(request: Request, body: CompletionRequest):
|
|
chat_request = ChatCompletionRequest(
|
|
model=body.model,
|
|
messages=[Message(role="user", content=body.prompt)],
|
|
temperature=body.temperature,
|
|
max_tokens=body.max_tokens,
|
|
top_p=body.top_p,
|
|
stream=body.stream
|
|
)
|
|
|
|
backend = _dispatcher.get_backend(chat_request.model)
|
|
if not backend:
|
|
raise HTTPException(status_code=400, detail=f"Unsupported model: {chat_request.model}")
|
|
|
|
try:
|
|
if body.stream:
|
|
chunk_generator = await backend.generate(chat_request)
|
|
|
|
async def convert_stream():
|
|
async for chunk in chunk_generator:
|
|
# 转换格式后需要序列化为JSON字符串
|
|
completion_chunk = {
|
|
"id": chunk.get("id", f"cmpl-{uuid.uuid4()}"),
|
|
"object": "text_completion.chunk",
|
|
"created": chunk.get("created", int(time.time())),
|
|
"model": chat_request.model,
|
|
"choices": [{
|
|
"text": chunk["choices"][0]["delta"].get("content", ""),
|
|
"index": 0,
|
|
"logprobs": None,
|
|
"finish_reason": chunk["choices"][0].get("finish_reason")
|
|
}]
|
|
}
|
|
# 添加SSE格式包装
|
|
yield f"data: {json.dumps(completion_chunk)}\n\n"
|
|
|
|
# 添加流结束标记
|
|
yield "data: [DONE]\n\n"
|
|
|
|
return StreamingResponse(
|
|
convert_stream(),
|
|
media_type="text/event-stream"
|
|
)
|
|
else:
|
|
chat_response = await backend.generate(chat_request)
|
|
return JSONResponse({
|
|
"id": f"cmpl-{uuid.uuid4()}",
|
|
"object": "text_completion",
|
|
"created": int(time.time()),
|
|
"model": chat_request.model,
|
|
"choices": [{
|
|
"text": chat_response["choices"][0]["message"]["content"],
|
|
"index": 0,
|
|
"logprobs": None,
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": chat_response.get("usage", {
|
|
"prompt_tokens": 0,
|
|
"completion_tokens": 0,
|
|
"total_tokens": 0
|
|
})
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Completion error: {str(e)}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
logging.getLogger().handlers[0].flush() |