You've already forked ModuleLLM-OpenAI-Plugin
mirror of
https://github.com/m5stack/ModuleLLM-OpenAI-Plugin.git
synced 2026-05-20 11:37:26 -07:00
initial commit
This commit is contained in:
+390
@@ -0,0 +1,390 @@
|
||||
import os
|
||||
import uuid
|
||||
import yaml
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
import logging
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from aiostream import stream
|
||||
from llm_client import LLMClient
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("api")
|
||||
|
||||
app = FastAPI(title="OpenAI Compatible API Server")
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
class Config:
|
||||
def __init__(self):
|
||||
with open("config.yaml") as f:
|
||||
self.data = yaml.safe_load(f)
|
||||
|
||||
config = Config()
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
temperature: Optional[float] = 0.7
|
||||
max_tokens: Optional[int] = 1000
|
||||
stream: Optional[bool] = False
|
||||
|
||||
@app.middleware("http")
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
if request.url.path.startswith("/v1"):
|
||||
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
|
||||
if api_key != os.getenv("API_KEY"):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"error": "Invalid authentication credentials"}
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
class BaseModelBackend:
|
||||
def __init__(self, model_config):
|
||||
self.config = model_config
|
||||
|
||||
async def generate(self, request: ChatCompletionRequest):
|
||||
raise NotImplementedError
|
||||
|
||||
class TestBackend(BaseModelBackend):
|
||||
async def generate(self, request: ChatCompletionRequest):
|
||||
if request.stream:
|
||||
async def chunk_generator():
|
||||
content_parts = ["🤣", "👉🏻", "🤡"]
|
||||
messages=[m.model_dump() for m in request.messages]
|
||||
print(f"messages:_____________{messages}______________")
|
||||
for i, part in enumerate(content_parts):
|
||||
yield {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": request.model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": part,
|
||||
"role": "assistant" if i == 0 else None,
|
||||
"function_call": None,
|
||||
"tool_calls": None
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop" if i == len(content_parts)-1 else None
|
||||
}],
|
||||
"service_tier": None,
|
||||
"system_fingerprint": None,
|
||||
"usage": None
|
||||
}
|
||||
return chunk_generator()
|
||||
else:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": request.model,
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "🤣👉🏻🤡",
|
||||
"function_call": None,
|
||||
"tool_calls": None
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
class OpenAIProxyBackend(BaseModelBackend):
|
||||
async def generate(self, request: ChatCompletionRequest):
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
client = AsyncOpenAI(
|
||||
api_key=self.config["api_key"],
|
||||
base_url=self.config["base_url"]
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=self.config["model"],
|
||||
messages=[m.model_dump() for m in request.messages],
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
stream=request.stream
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
async def async_wrapper():
|
||||
async for chunk in response:
|
||||
yield chunk
|
||||
return async_wrapper()
|
||||
return response
|
||||
|
||||
class LlmClientBackend(BaseModelBackend):
|
||||
MAX_CONTEXT_LENGTH = 500
|
||||
POOL_SIZE = 2 # 新增连接池大小限制
|
||||
|
||||
def __init__(self, model_config):
|
||||
super().__init__(model_config)
|
||||
self._client_pool = [] # 可用连接池
|
||||
self._active_clients = {} # 使用中的连接
|
||||
self._pool_lock = asyncio.Lock()
|
||||
self.logger = logging.getLogger("api.client")
|
||||
|
||||
async def _get_client(self, request):
|
||||
async with self._pool_lock:
|
||||
# 尝试从池中获取可用连接
|
||||
if self._client_pool:
|
||||
client = self._client_pool.pop()
|
||||
self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}")
|
||||
return client
|
||||
|
||||
# 检查是否达到最大连接数
|
||||
if len(self._active_clients) >= self.POOL_SIZE:
|
||||
raise RuntimeError("Connection pool exhausted")
|
||||
|
||||
# 创建新连接
|
||||
self.logger.debug("🆕 Creating new LLM client")
|
||||
client = LLMClient(
|
||||
host=self.config["host"],
|
||||
port=self.config["port"]
|
||||
)
|
||||
self._active_clients[id(client)] = client
|
||||
|
||||
# 初始化连接
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
client.setup,
|
||||
{
|
||||
"model": self.config["model_name"],
|
||||
"response_format": "llm.utf-8.stream",
|
||||
"input": "llm.utf-8",
|
||||
"enoutput": True,
|
||||
"max_token_len": request.max_tokens,
|
||||
"temperature": request.temperature,
|
||||
"prompt": next(
|
||||
(m.content for m in request.messages if m.role == "system"),
|
||||
self.config.get("system_prompt", "You are a helpful assistant")
|
||||
)
|
||||
}
|
||||
)
|
||||
return client
|
||||
|
||||
async def _release_client(self, client):
|
||||
async with self._pool_lock:
|
||||
# 将连接放回池中供后续使用
|
||||
self._client_pool.append(client)
|
||||
self.logger.debug(f"🔙 Returned client to pool | ID:{id(client)}")
|
||||
|
||||
async def inference_stream(self, query: str, request: ChatCompletionRequest):
|
||||
client = await self._get_client(request)
|
||||
try:
|
||||
self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
sync_gen = client.inference_stream(query)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 使用闭包捕获生成器状态
|
||||
def get_next():
|
||||
try:
|
||||
return next(sync_gen)
|
||||
except StopIteration:
|
||||
return None # 返回哨兵值代替抛出异常
|
||||
|
||||
chunk = await loop.run_in_executor(None, get_next)
|
||||
if chunk is None: # 检测到生成器结束
|
||||
break
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self.logger.error(f"Inference error: {str(e)}")
|
||||
yield f"[ERROR: {str(e)}]"
|
||||
break
|
||||
finally:
|
||||
await self._release_client(client)
|
||||
|
||||
def _truncate_history(self, messages: List[Message]) -> List[Message]:
|
||||
"""Truncate history to fit model context window"""
|
||||
total_length = 0
|
||||
keep_messages = []
|
||||
|
||||
# Process in reverse to keep latest messages
|
||||
for msg in reversed(messages):
|
||||
if msg.role == "system": # Always keep system messages
|
||||
keep_messages.insert(0, msg)
|
||||
continue
|
||||
|
||||
msg_length = len(msg.content)
|
||||
if total_length + msg_length > self.MAX_CONTEXT_LENGTH:
|
||||
break
|
||||
total_length += msg_length
|
||||
keep_messages.insert(0, msg) # Maintain original order
|
||||
|
||||
return keep_messages
|
||||
|
||||
async def generate(self, request: ChatCompletionRequest):
|
||||
try:
|
||||
truncated_messages = self._truncate_history(request.messages)
|
||||
|
||||
query = "\n".join([
|
||||
f"{m.role}: {m.content}"
|
||||
for m in truncated_messages
|
||||
if m.role != "system"
|
||||
])
|
||||
|
||||
self.logger.debug(
|
||||
f"Context truncated: Original {len(request.messages)} → Kept {len(truncated_messages)} "
|
||||
f"Total length:{len(query)} chars"
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
async def chunk_generator():
|
||||
try:
|
||||
async for chunk in self.inference_stream(query, request):
|
||||
yield {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": request.model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"content": chunk},
|
||||
"finish_reason": None
|
||||
}]
|
||||
}
|
||||
# Add normal completion marker
|
||||
yield {
|
||||
"choices": [{
|
||||
"delta": {},
|
||||
"finish_reason": "stop"
|
||||
}]
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"Stream generation error: {str(e)}")
|
||||
yield {
|
||||
"error": {
|
||||
"message": f"Stream generation failed: {str(e)}",
|
||||
"type": "api_error"
|
||||
}
|
||||
}
|
||||
yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}
|
||||
raise
|
||||
return chunk_generator()
|
||||
else:
|
||||
full_response = ""
|
||||
async for chunk in self.inference_stream(query, request):
|
||||
full_response += chunk
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": request.model,
|
||||
"choices": [{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": full_response
|
||||
}
|
||||
}]
|
||||
}
|
||||
except RuntimeError as e:
|
||||
self.logger.error(f"Connection error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Model service connection failed: {str(e)}"
|
||||
)
|
||||
|
||||
class ModelDispatcher:
|
||||
def __init__(self):
|
||||
self.backends = {}
|
||||
self.load_models()
|
||||
|
||||
def load_models(self):
|
||||
for model_name, model_config in config.data["models"].items():
|
||||
if model_config["type"] == "openai_proxy":
|
||||
self.backends[model_name] = OpenAIProxyBackend(model_config)
|
||||
elif model_config["type"] == "tcp_client":
|
||||
self.backends[model_name] = LlmClientBackend(model_config)
|
||||
elif model_config["type"] == "llama.cpp":
|
||||
self.backends[model_name] = TestBackend(model_config)
|
||||
|
||||
def get_backend(self, model_name):
|
||||
return self.backends.get(model_name)
|
||||
|
||||
_dispatcher = ModelDispatcher()
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request, body: ChatCompletionRequest):
|
||||
backend = _dispatcher.get_backend(body.model)
|
||||
if not backend:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported model: {body.model}"
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"Received request: {body.model_dump()}")
|
||||
|
||||
if body.stream:
|
||||
chunk_generator = await backend.generate(body)
|
||||
if not chunk_generator:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to generate stream response"
|
||||
)
|
||||
|
||||
async def format_stream():
|
||||
try:
|
||||
async for chunk in chunk_generator:
|
||||
if isinstance(chunk, dict):
|
||||
chunk_dict = chunk
|
||||
else:
|
||||
chunk_dict = chunk.model_dump()
|
||||
|
||||
json_chunk = json.dumps(chunk_dict, ensure_ascii=False)
|
||||
print(f"Sending chunk: {json_chunk}")
|
||||
yield f"data: {json_chunk}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Stream interrupted: {str(e)}")
|
||||
yield f"data: {{'error': 'Stream interrupted'}}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
format_stream(),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
response = await backend.generate(body)
|
||||
print(f"Sending response: {response}")
|
||||
return JSONResponse(content=response)
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
logger.error(f"Processing error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
logging.getLogger().handlers[0].flush()
|
||||
+121
@@ -0,0 +1,121 @@
|
||||
import json
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
import logging
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger("llm_client")
|
||||
logger.setLevel(logging.DEBUG) # 根据需要调整级别
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self, host: str = "localhost", port: int = 10001):
|
||||
self._lock = threading.Lock() # 添加线程锁
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.sock = None
|
||||
self.work_id = None # 保存服务端返回的work_id
|
||||
self._initialized = False # 新增初始化状态标记
|
||||
self._connect() # 添加连接方法
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, _exc_type, _exc_val, _exc_tb):
|
||||
self.close()
|
||||
|
||||
def _connect(self):
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
self.sock.connect((self.host, self.port))
|
||||
except ConnectionRefusedError as e:
|
||||
raise RuntimeError(f"无法连接到 {self.host}:{self.port}") from e
|
||||
|
||||
def close(self):
|
||||
if self.sock:
|
||||
self.sock.close()
|
||||
self.sock = None
|
||||
|
||||
def _send_request(self, action: str, data: dict) -> str:
|
||||
"""通用请求发送方法"""
|
||||
request_id = str(uuid.uuid4())
|
||||
payload = {
|
||||
"request_id": request_id,
|
||||
"work_id": self.work_id or "llm",
|
||||
"action": action,
|
||||
"object": "llm.setup" if action == "setup" else "llm.utf-8",
|
||||
"data": data
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Sending request: [ID:{request_id}] "
|
||||
f"Action:{action} WorkID:{payload['work_id']}\n"
|
||||
f"Data: {str(data)[:100]}..."
|
||||
)
|
||||
|
||||
self.sock.sendall(json.dumps(payload, ensure_ascii=False).encode('utf-8'))
|
||||
return request_id
|
||||
|
||||
def setup(self, model_config: dict) -> dict:
|
||||
if not self.sock:
|
||||
self._connect()
|
||||
request_id = self._send_request("setup", model_config)
|
||||
return self._wait_response(request_id)
|
||||
|
||||
def inference_stream(self, query: str) -> Generator[str, None, None]:
|
||||
request_id = self._send_request("inference", query)
|
||||
|
||||
while True:
|
||||
response = json.loads(self.sock.recv(4096).decode())
|
||||
if response["request_id"] != request_id:
|
||||
continue
|
||||
|
||||
yield response["data"]["delta"]
|
||||
if response["data"].get("finish", False):
|
||||
self.work_id = response["work_id"]
|
||||
break
|
||||
|
||||
def exit(self) -> dict:
|
||||
request_id = self._send_request("exit", {})
|
||||
result = self._wait_response(request_id)
|
||||
self._initialized = False
|
||||
return result
|
||||
|
||||
def _wait_response(self, request_id: str) -> dict:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 10:
|
||||
response = json.loads(self.sock.recv(4096).decode())
|
||||
if response["request_id"] == request_id:
|
||||
if response["error"]["code"] != 0:
|
||||
raise RuntimeError(f"Server error: {response['error']['message']}")
|
||||
self.work_id = response["work_id"]
|
||||
return response
|
||||
raise TimeoutError("No response from server")
|
||||
|
||||
def connect(self):
|
||||
"""显式连接方法"""
|
||||
with self._lock:
|
||||
if not self.sock:
|
||||
self._connect()
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
with LLMClient(host='192.168.20.183') as client:
|
||||
setup_response = client.setup({
|
||||
"model": "deepseek-r1-1.5B-ax630c",
|
||||
"response_format": "llm.utf-8.stream",
|
||||
"input": "llm.utf-8",
|
||||
"enoutput": True,
|
||||
"max_token_len": 256,
|
||||
"prompt": "You are a helpful assistant"
|
||||
})
|
||||
print("Setup response:", setup_response)
|
||||
|
||||
for chunk in client.inference_stream("What's your name?"):
|
||||
print("Received chunk:", chunk)
|
||||
|
||||
exit_response = client.exit()
|
||||
print("Exit response:", exit_response)
|
||||
Reference in New Issue
Block a user