initial commit

This commit is contained in:
LittleMouse
2025-02-13 18:10:46 +08:00
commit 2aede93e59
2 changed files with 511 additions and 0 deletions
+390
View File
@@ -0,0 +1,390 @@
import os
import uuid
import yaml
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from typing import Optional, List
import logging
from slowapi import Limiter
from slowapi.util import get_remote_address
import time
import json
import asyncio
from aiostream import stream
from llm_client import LLMClient
# Configure the root logger: DEBUG level, one stream handler, timestamped
# "time - logger - level - message" records.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.DEBUG,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("api")

# ASGI application object served by this module.
app = FastAPI(title="OpenAI Compatible API Server")

# Limiter keyed by the remote address helper from slowapi.
# NOTE(review): nothing in the visible code attaches it to the app or a route.
limiter = Limiter(key_func=get_remote_address)
class Config:
    """Load ``config.yaml`` (relative to the working directory) into ``data``.

    Attributes:
        data: Whatever ``yaml.safe_load`` returns for the file contents.

    Raises:
        OSError: If ``config.yaml`` cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """

    def __init__(self):
        # Open in binary mode so PyYAML detects the encoding itself instead
        # of the bytes being decoded with the platform's locale default,
        # which corrupts or rejects non-ASCII values on some systems.
        with open("config.yaml", "rb") as f:
            self.data = yaml.safe_load(f)


# Parsed once at import time and shared by the rest of the module.
config = Config()
# One chat message: the speaker's role and the text it contributed.
# (Documented with comments rather than a docstring so the generated
# schema is left untouched.)
class Message(BaseModel):
    # Speaker role; the code in this file checks for the value "system".
    role: str
    # Message text.
    content: str
# Request body accepted by the /v1/chat/completions endpoint.
# (Documented with comments rather than a docstring so the generated
# schema is left untouched.)
class ChatCompletionRequest(BaseModel):
    # Name used to look up the backend that serves the request.
    model: str
    # Conversation history.
    messages: List[Message]
    # Sampling temperature forwarded to the backend.
    temperature: Optional[float] = 0.7
    # Generation length limit forwarded to the backend.
    max_tokens: Optional[int] = 1000
    # When true the endpoint answers with a server-sent event stream.
    stream: Optional[bool] = False
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Reject ``/v1`` requests that do not present the configured API key.

    The key is read from the ``Authorization`` header (with or without a
    leading ``Bearer `` scheme) and compared with the ``API_KEY``
    environment variable. Requests outside ``/v1`` pass through unchecked.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        A 401 JSON response when the key does not match (or ``API_KEY`` is
        unset); otherwise the downstream response.
    """
    # Local import keeps this block self-contained.
    import hmac

    if request.url.path.startswith("/v1"):
        header = request.headers.get("Authorization", "")
        scheme = "Bearer "
        # Strip the scheme only when it is a prefix; str.replace() would
        # also delete the same text from the middle of a token.
        api_key = header[len(scheme):] if header.startswith(scheme) else header

        expected = os.getenv("API_KEY")
        # Constant-time comparison avoids leaking key contents through
        # response timing. Bytes are compared because compare_digest()
        # rejects non-ASCII str arguments.
        authorized = expected is not None and hmac.compare_digest(
            api_key.encode("utf-8"), expected.encode("utf-8")
        )
        if not authorized:
            return JSONResponse(
                status_code=401,
                content={"error": "Invalid authentication credentials"},
            )
    return await call_next(request)
class BaseModelBackend:
    """Common base for model backends; keeps the per-model configuration."""

    def __init__(self, model_config):
        # Stored as-is; subclasses read their own keys from it.
        self.config = model_config

    async def generate(self, request: ChatCompletionRequest):
        """Produce a completion for *request*. Subclasses must override."""
        raise NotImplementedError
class TestBackend(BaseModelBackend):
    """Stub backend that always answers with the same three emoji."""

    async def generate(self, request: ChatCompletionRequest):
        """Return a canned completion for *request*.

        Returns:
            A ``chat.completion`` dict when ``request.stream`` is falsy,
            otherwise an async generator yielding one
            ``chat.completion.chunk`` dict per emoji.
        """
        if not request.stream:
            # Fixed reply; the usage numbers are hard-coded placeholders.
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "🤣👉🏻🤡",
                        "function_call": None,
                        "tool_calls": None
                    },
                    "finish_reason": "stop",
                    "index": 0
                }],
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 20,
                    "total_tokens": 30
                }
            }

        async def chunk_generator():
            content_parts = ["🤣", "👉🏻", "🤡"]
            # Dump the incoming conversation to stdout for debugging.
            messages = [m.model_dump() for m in request.messages]
            print(f"messages:_____________{messages}______________")
            final_position = len(content_parts) - 1
            for position, piece in enumerate(content_parts):
                opens_stream = position == 0
                closes_stream = position == final_position
                # Only the first chunk names the role; only the last one
                # carries a finish reason. Every chunk gets a fresh id.
                yield {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{
                        "index": 0,
                        "delta": {
                            "content": piece,
                            "role": "assistant" if opens_stream else None,
                            "function_call": None,
                            "tool_calls": None
                        },
                        "logprobs": None,
                        "finish_reason": "stop" if closes_stream else None
                    }],
                    "service_tier": None,
                    "system_fingerprint": None,
                    "usage": None
                }

        return chunk_generator()
class OpenAIProxyBackend(BaseModelBackend):
    """Backend that forwards requests to an OpenAI-style HTTP API."""

    async def generate(self, request: ChatCompletionRequest):
        """Forward *request* upstream and hand back the upstream result.

        Reads ``api_key``, ``base_url`` and ``model`` from the backend
        configuration. The model sent upstream is the configured one, not
        ``request.model``.

        Returns:
            The upstream response object when ``request.stream`` is falsy,
            otherwise an async generator re-yielding the upstream chunks.
        """
        # Imported lazily, so the package is only needed by this backend.
        from openai import AsyncOpenAI

        upstream = AsyncOpenAI(
            api_key=self.config["api_key"],
            base_url=self.config["base_url"],
        )
        outgoing_messages = [m.model_dump() for m in request.messages]
        response = await upstream.chat.completions.create(
            model=self.config["model"],
            messages=outgoing_messages,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            stream=request.stream,
        )
        if not request.stream:
            return response

        async def async_wrapper():
            # Re-expose the upstream stream as a plain async generator.
            async for chunk in response:
                yield chunk

        return async_wrapper()
class LlmClientBackend(BaseModelBackend):
    """Backend that serves requests through a small pool of ``LLMClient``s.

    At most ``POOL_SIZE`` clients exist at any time. A client is set up once,
    when it is created, so a reused client keeps the system prompt,
    ``max_tokens`` and ``temperature`` of the request that created it.
    """

    MAX_CONTEXT_LENGTH = 500  # character budget for non-system history
    POOL_SIZE = 2  # maximum number of live clients

    # Marks the end of the blocking generator. A dedicated object is used so
    # that no value produced by the client can be mistaken for it.
    _STREAM_END = object()

    def __init__(self, model_config):
        super().__init__(model_config)
        self._client_pool = []  # idle clients ready for reuse
        self._active_clients = {}  # every live client, keyed by id()
        self._pool_lock = asyncio.Lock()
        self.logger = logging.getLogger("api.client")

    async def _get_client(self, request):
        """Return an idle client, or create and set up a new one.

        Args:
            request: Request whose ``max_tokens``, ``temperature`` and system
                message configure a newly created client.

        Raises:
            RuntimeError: If no client is idle and ``POOL_SIZE`` clients
                already exist.
        """
        async with self._pool_lock:
            # Prefer an idle client from the pool.
            if self._client_pool:
                client = self._client_pool.pop()
                self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}")
                return client

            # Refuse to grow beyond the configured number of clients.
            if len(self._active_clients) >= self.POOL_SIZE:
                raise RuntimeError("Connection pool exhausted")

            self.logger.debug("🆕 Creating new LLM client")
            loop = asyncio.get_running_loop()
            # The constructor connects a blocking socket, so it runs in the
            # executor rather than on the event loop.
            client = await loop.run_in_executor(
                None,
                lambda: LLMClient(
                    host=self.config["host"],
                    port=self.config["port"],
                ),
            )
            setup_options = {
                "model": self.config["model_name"],
                "response_format": "llm.utf-8.stream",
                "input": "llm.utf-8",
                "enoutput": True,
                "max_token_len": request.max_tokens,
                "temperature": request.temperature,
                # First system message of the request, else the configured
                # prompt, else a generic default.
                "prompt": next(
                    (m.content for m in request.messages if m.role == "system"),
                    self.config.get("system_prompt", "You are a helpful assistant")
                ),
            }
            try:
                await loop.run_in_executor(None, client.setup, setup_options)
            except BaseException:
                # A client that failed setup must not occupy a pool slot;
                # otherwise repeated failures exhaust the pool for good.
                client.close()
                raise
            self._active_clients[id(client)] = client
            return client

    async def _release_client(self, client):
        """Put a healthy client back into the idle pool."""
        async with self._pool_lock:
            self._client_pool.append(client)
            self.logger.debug(f"🔙 Returned client to pool | ID:{id(client)}")

    async def _discard_client(self, client):
        """Close a failed client and free its pool slot."""
        async with self._pool_lock:
            self._active_clients.pop(id(client), None)
        try:
            client.close()
        except OSError as e:
            self.logger.debug(f"Error while closing client: {str(e)}")

    async def inference_stream(self, query: str, request: ChatCompletionRequest):
        """Yield response fragments for *query* as they arrive.

        The client's blocking generator is advanced in the default executor.
        If it raises, a single ``"[ERROR: ...]"`` string is yielded and the
        stream ends; the failed client is then discarded instead of being
        returned to the pool.
        """
        client = await self._get_client(request)
        healthy = True
        try:
            self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}")
            loop = asyncio.get_running_loop()
            sync_gen = client.inference_stream(query)

            def get_next():
                # StopIteration cannot be delivered through a Future, so the
                # end of the generator is reported with a sentinel instead.
                try:
                    return next(sync_gen)
                except StopIteration:
                    return self._STREAM_END

            while True:
                try:
                    chunk = await loop.run_in_executor(None, get_next)
                except Exception as e:
                    healthy = False
                    self.logger.error(f"Inference error: {str(e)}")
                    yield f"[ERROR: {str(e)}]"
                    break
                if chunk is self._STREAM_END:
                    break
                yield chunk
        finally:
            if healthy:
                await self._release_client(client)
            else:
                await self._discard_client(client)

    def _truncate_history(self, messages: List[Message]) -> List[Message]:
        """Keep the newest messages that fit in ``MAX_CONTEXT_LENGTH`` chars.

        System messages are always kept and do not count toward the budget.
        Non-system messages are taken newest-first until the next one would
        exceed the budget; everything older than that point is dropped. The
        result preserves the original order.

        NOTE(review): when the newest non-system message alone exceeds the
        budget, no non-system message is kept at all.
        """
        kept_newest_first = []
        total_length = 0
        budget_exhausted = False
        for msg in reversed(messages):
            if msg.role == "system":
                kept_newest_first.append(msg)
                continue
            if budget_exhausted:
                # Keep scanning only to collect older system messages.
                continue
            msg_length = len(msg.content)
            if total_length + msg_length > self.MAX_CONTEXT_LENGTH:
                budget_exhausted = True
                continue
            total_length += msg_length
            kept_newest_first.append(msg)
        kept_newest_first.reverse()
        return kept_newest_first

    @staticmethod
    def _make_chunk(request, delta, finish_reason):
        """Build one ``chat.completion.chunk`` dict for *request*."""
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        }

    async def _stream_chunks(self, query, request):
        """Wrap the raw fragments of ``inference_stream`` into chunk dicts.

        Ends with a chunk whose ``finish_reason`` is ``"stop"``. On failure an
        error dict and a stop chunk are yielded before the exception is
        re-raised to the consumer.
        """
        try:
            async for chunk in self.inference_stream(query, request):
                yield self._make_chunk(request, {"content": chunk}, None)
            # Normal completion marker.
            yield self._make_chunk(request, {}, "stop")
        except Exception as e:
            self.logger.error(f"Stream generation error: {str(e)}")
            yield {
                "error": {
                    "message": f"Stream generation failed: {str(e)}",
                    "type": "api_error"
                }
            }
            yield self._make_chunk(request, {}, "stop")
            raise

    async def generate(self, request: ChatCompletionRequest):
        """Answer *request* through a pooled client.

        Returns:
            An async generator of chunk dicts when ``request.stream`` is
            truthy, otherwise a single ``chat.completion`` dict.

        Raises:
            HTTPException: Status 400 when a ``RuntimeError`` is raised while
                producing a non-streaming response. In streaming mode the
                client is acquired lazily, once the returned generator is
                consumed, so such errors surface during iteration instead.
        """
        try:
            truncated_messages = self._truncate_history(request.messages)
            # System messages are passed at client setup, not in the query.
            query = "\n".join(
                f"{m.role}: {m.content}"
                for m in truncated_messages
                if m.role != "system"
            )
            self.logger.debug(
                f"Context truncated: Original {len(request.messages)} → Kept {len(truncated_messages)} "
                f"Total length:{len(query)} chars"
            )
            if request.stream:
                return self._stream_chunks(query, request)

            parts = [chunk async for chunk in self.inference_stream(query, request)]
            return {
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "".join(parts)
                    },
                    "finish_reason": "stop"
                }]
            }
        except RuntimeError as e:
            self.logger.error(f"Connection error: {str(e)}")
            raise HTTPException(
                status_code=400,
                detail=f"Model service connection failed: {str(e)}"
            )
class ModelDispatcher:
    """Maps configured model names to backend instances."""

    def __init__(self):
        self.backends = {}
        self.load_models()

    def load_models(self):
        """Create one backend per entry under ``models`` in the config.

        Entries whose ``type`` is not one of the known values are skipped
        without any message.
        """
        known_types = (
            ("openai_proxy", OpenAIProxyBackend),
            ("tcp_client", LlmClientBackend),
            ("llama.cpp", TestBackend),
        )
        for name, settings in config.data["models"].items():
            declared_type = settings["type"]
            for type_name, backend_cls in known_types:
                if declared_type == type_name:
                    self.backends[name] = backend_cls(settings)
                    break

    def get_backend(self, model_name):
        """Return the backend registered for *model_name*, or ``None``."""
        return self.backends.get(model_name)


# Built at import time, so the configuration is read once per process.
_dispatcher = ModelDispatcher()
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, body: ChatCompletionRequest):
    """Serve a chat completion, either as one JSON body or as an SSE stream.

    Args:
        request: Raw HTTP request (not used in the body of this handler).
        body: Parsed completion request.

    Returns:
        A ``StreamingResponse`` of ``data: <json>`` events terminated by
        ``data: [DONE]`` when ``body.stream`` is truthy, otherwise a
        ``JSONResponse`` with the backend's result.

    Raises:
        HTTPException: 400 for an unknown model; 500 for any unexpected
            error raised while producing the response.
    """
    backend = _dispatcher.get_backend(body.model)
    if not backend:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model: {body.model}"
        )
    try:
        logger.debug("Received request: %s", body.model_dump())
        if body.stream:
            chunk_generator = await backend.generate(body)
            if not chunk_generator:
                raise HTTPException(
                    status_code=500,
                    detail="Failed to generate stream response"
                )

            async def format_stream():
                try:
                    async for chunk in chunk_generator:
                        # Backends yield either plain dicts or objects that
                        # expose model_dump().
                        chunk_dict = chunk if isinstance(chunk, dict) else chunk.model_dump()
                        json_chunk = json.dumps(chunk_dict, ensure_ascii=False)
                        logger.debug("Sending chunk: %s", json_chunk)
                        yield f"data: {json_chunk}\n\n"
                except Exception as e:
                    logger.error(f"Stream interrupted: {str(e)}")
                    # Serialize with json.dumps so the event is valid JSON;
                    # a hand-written single-quoted dict is not parseable.
                    error_event = json.dumps({"error": "Stream interrupted"})
                    yield f"data: {error_event}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                format_stream(),
                media_type="text/event-stream"
            )

        response = await backend.generate(body)
        # Same normalization as the streaming path: a backend may return an
        # object instead of a dict, which JSONResponse cannot serialize.
        if not isinstance(response, dict):
            response = response.model_dump()
        logger.debug("Sending response: %s", response)
        return JSONResponse(content=response)
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
    # NOTE(review): the flattened source does not show whether this flush
    # was inside the guard; it is placed here, after the server returns.
    logging.getLogger().handlers[0].flush()
+121
View File
@@ -0,0 +1,121 @@
import json
import socket
import time
import uuid
from contextlib import contextmanager
from typing import Generator
import logging
import threading
# Module-level logger used by LLMClient for request tracing.
logger = logging.getLogger("llm_client")
logger.setLevel(logging.DEBUG) # adjust the level as needed
class LLMClient:
    """Blocking TCP client that exchanges JSON messages with an LLM service.

    Requests are JSON objects carrying a ``request_id``; responses are JSON
    objects matched to requests by that id. Incoming bytes are buffered and
    split into messages by parsing, so a message may arrive in several
    ``recv`` calls, or several messages in one.
    """

    # JSONDecoder keeps no per-parse state, so one instance is shared.
    _json_decoder = json.JSONDecoder()
    # Returned by _pop_message() when no complete message is buffered yet.
    _INCOMPLETE = object()
    # Bytes received but not yet parsed (class-level default; each instance
    # gets its own value on first assignment).
    _rx_buffer = b""
    # Give up when this many bytes are buffered without forming a message.
    _MAX_BUFFERED_BYTES = 1 << 20
    # Seconds _wait_response() waits for a matching response.
    _RESPONSE_TIMEOUT = 10

    def __init__(self, host: str = "localhost", port: int = 10001):
        self._lock = threading.Lock()  # guards connect()
        self.host = host
        self.port = port
        self.sock = None
        self.work_id = None  # work_id most recently returned by the server
        self._initialized = False
        self._rx_buffer = b""
        self._connect()

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        self.close()

    def _connect(self):
        """Open the TCP connection.

        Raises:
            RuntimeError: If the server refuses the connection.
            OSError: For any other connection failure.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((self.host, self.port))
        except ConnectionRefusedError as e:
            sock.close()
            raise RuntimeError(f"无法连接到 {self.host}:{self.port}") from e
        except OSError:
            # Do not leak the descriptor or keep a half-open socket around.
            sock.close()
            raise
        self.sock = sock
        self._rx_buffer = b""

    def close(self):
        """Close the connection, if open, and drop any buffered input."""
        if self.sock:
            self.sock.close()
            self.sock = None
        self._rx_buffer = b""

    def _send_request(self, action: str, data: object) -> str:
        """Send one request and return the ``request_id`` generated for it.

        Args:
            action: Action name placed in the payload (this class sends
                ``"setup"``, ``"inference"`` and ``"exit"``).
            data: Payload body; a dict for setup/exit, the query string for
                inference.
        """
        request_id = str(uuid.uuid4())
        payload = {
            "request_id": request_id,
            "work_id": self.work_id or "llm",
            "action": action,
            "object": "llm.setup" if action == "setup" else "llm.utf-8",
            "data": data
        }
        logger.debug(
            f"Sending request: [ID:{request_id}] "
            f"Action:{action} WorkID:{payload['work_id']}\n"
            f"Data: {str(data)[:100]}..."
        )
        self.sock.sendall(json.dumps(payload, ensure_ascii=False).encode('utf-8'))
        return request_id

    def _pop_message(self):
        """Remove and return the first complete JSON value in the buffer.

        Returns ``_INCOMPLETE`` when the buffer does not yet hold a complete
        value.
        """
        try:
            text = self._rx_buffer.decode("utf-8")
        except UnicodeDecodeError as exc:
            # A multi-byte character may be split across recv() calls; parse
            # only the cleanly decodable prefix and wait for the rest.
            text = self._rx_buffer[:exc.start].decode("utf-8")
        # raw_decode() does not skip leading whitespace on its own.
        start = len(text) - len(text.lstrip(" \t\r\n"))
        if start == len(text):
            return self._INCOMPLETE
        try:
            message, end = self._json_decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            return self._INCOMPLETE
        # ``end`` counts characters; convert to bytes before trimming.
        consumed = len(text[:end].encode("utf-8"))
        self._rx_buffer = self._rx_buffer[consumed:]
        return message

    def _recv_json(self) -> dict:
        """Block until one complete JSON message is available and return it.

        Raises:
            ConnectionError: If the server closes the connection.
            RuntimeError: If too much data accumulates without forming a
                message.
        """
        while True:
            message = self._pop_message()
            if message is not self._INCOMPLETE:
                return message
            if len(self._rx_buffer) > self._MAX_BUFFERED_BYTES:
                raise RuntimeError("Received data is not a valid JSON message")
            chunk = self.sock.recv(4096)
            if not chunk:
                raise ConnectionError("Connection closed by server")
            self._rx_buffer += chunk

    def setup(self, model_config: dict) -> dict:
        """Send a ``setup`` request and return the server's response."""
        if not self.sock:
            self._connect()
        request_id = self._send_request("setup", model_config)
        return self._wait_response(request_id)

    def inference_stream(self, query: str) -> Generator[str, None, None]:
        """Send *query* and yield each ``delta`` of the matching responses.

        Responses for other request ids are skipped. The stream ends after
        the response whose ``data.finish`` is truthy.

        Raises:
            RuntimeError: If a matching response reports a non-zero error
                code.
        """
        request_id = self._send_request("inference", query)
        while True:
            response = self._recv_json()
            if response["request_id"] != request_id:
                continue
            error = response.get("error")
            if isinstance(error, dict) and error.get("code", 0) != 0:
                raise RuntimeError(f"Server error: {error.get('message')}")
            yield response["data"]["delta"]
            if response["data"].get("finish", False):
                self.work_id = response["work_id"]
                break

    def exit(self) -> dict:
        """Send an ``exit`` request and return the server's response."""
        request_id = self._send_request("exit", {})
        result = self._wait_response(request_id)
        self._initialized = False
        return result

    def _wait_response(self, request_id: str) -> dict:
        """Wait up to ``_RESPONSE_TIMEOUT`` seconds for a matching response.

        Raises:
            RuntimeError: If the response reports a non-zero error code.
            TimeoutError: If no matching response arrives in time.
        """
        deadline = time.monotonic() + self._RESPONSE_TIMEOUT
        previous_timeout = self.sock.gettimeout()
        try:
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                # Without a socket timeout recv() blocks indefinitely and the
                # deadline is never checked.
                self.sock.settimeout(remaining)
                try:
                    response = self._recv_json()
                except socket.timeout:
                    break
                if response["request_id"] != request_id:
                    continue
                if response["error"]["code"] != 0:
                    raise RuntimeError(f"Server error: {response['error']['message']}")
                self.work_id = response["work_id"]
                return response
        finally:
            if self.sock is not None:
                self.sock.settimeout(previous_timeout)
        raise TimeoutError("No response from server")

    def connect(self):
        """Connect explicitly if the socket is not already open."""
        with self._lock:
            if not self.sock:
                self._connect()
# Usage example: set up a model, stream one answer, then end the session.
if __name__ == "__main__":
    demo_setup = {
        "model": "deepseek-r1-1.5B-ax630c",
        "response_format": "llm.utf-8.stream",
        "input": "llm.utf-8",
        "enoutput": True,
        "max_token_len": 256,
        "prompt": "You are a helpful assistant"
    }
    with LLMClient(host='192.168.20.183') as client:
        setup_response = client.setup(demo_setup)
        print("Setup response:", setup_response)

        for chunk in client.inference_stream("What's your name?"):
            print("Received chunk:", chunk)

        exit_response = client.exit()
        print("Exit response:", exit_response)