From 2aede93e598b7d010474bd65125baac2ce023c17 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Thu, 13 Feb 2025 18:10:46 +0800 Subject: [PATCH] initial commit --- api_server.py | 390 ++++++++++++++++++++++++++++++++++++++++++++++++++ llm_client.py | 121 ++++++++++++++++ 2 files changed, 511 insertions(+) create mode 100644 api_server.py create mode 100644 llm_client.py diff --git a/api_server.py b/api_server.py new file mode 100644 index 0000000..928eeb5 --- /dev/null +++ b/api_server.py @@ -0,0 +1,390 @@ +import os +import uuid +import yaml +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel +from typing import Optional, List +import logging +from slowapi import Limiter +from slowapi.util import get_remote_address +import time +import json +import asyncio +from aiostream import stream +from llm_client import LLMClient + +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), + ] +) +logger = logging.getLogger("api") + +app = FastAPI(title="OpenAI Compatible API Server") +limiter = Limiter(key_func=get_remote_address) + +class Config: + def __init__(self): + with open("config.yaml") as f: + self.data = yaml.safe_load(f) + +config = Config() + +class Message(BaseModel): + role: str + content: str + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[Message] + temperature: Optional[float] = 0.7 + max_tokens: Optional[int] = 1000 + stream: Optional[bool] = False + +@app.middleware("http") +async def auth_middleware(request: Request, call_next): + if request.url.path.startswith("/v1"): + api_key = request.headers.get("Authorization", "").replace("Bearer ", "") + if api_key != os.getenv("API_KEY"): + return JSONResponse( + status_code=401, + content={"error": "Invalid authentication credentials"} + ) + return await call_next(request) + +class BaseModelBackend: + def __init__(self, model_config): + self.config = model_config + + async def generate(self, request: ChatCompletionRequest): + raise NotImplementedError + +class TestBackend(BaseModelBackend): + async def generate(self, request: ChatCompletionRequest): + if request.stream: + async def chunk_generator(): + content_parts = ["🤣", "👉🏻", "🤡"] + messages=[m.model_dump() for m in request.messages] + print(f"messages:_____________{messages}______________") + for i, part in enumerate(content_parts): + yield { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": request.model, + "choices": [{ + "index": 0, + "delta": { + "content": part, + "role": "assistant" if i == 0 else None, + "function_call": None, + "tool_calls": None + }, + "logprobs": None, + "finish_reason": "stop" if i == len(content_parts)-1 else None + }], + "service_tier": None, + "system_fingerprint": None, + "usage": None + } + return chunk_generator() + else: + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": request.model, + "choices": [{ + "message": { + "role": "assistant", + "content": "🤣👉🏻🤡", + "function_call": None, + "tool_calls": None + }, + "finish_reason": "stop", + "index": 0 + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + } + } + +class OpenAIProxyBackend(BaseModelBackend): + async def generate(self, request: ChatCompletionRequest): + from openai import AsyncOpenAI + + client = AsyncOpenAI( + api_key=self.config["api_key"], + base_url=self.config["base_url"] + ) + + response = await client.chat.completions.create( + model=self.config["model"], + messages=[m.model_dump() for m in request.messages], + temperature=request.temperature, + max_tokens=request.max_tokens, + stream=request.stream + ) + + if request.stream: + async def async_wrapper(): + async for chunk in response: + yield chunk + return async_wrapper() + return response + +class LlmClientBackend(BaseModelBackend): + MAX_CONTEXT_LENGTH = 500 + POOL_SIZE = 2 # 新增连接池大小限制 + + def __init__(self, model_config): + super().__init__(model_config) + self._client_pool = [] # 可用连接池 + self._active_clients = {} # 使用中的连接 + self._pool_lock = asyncio.Lock() + self.logger = logging.getLogger("api.client") + + async def _get_client(self, request): + async with self._pool_lock: + # 尝试从池中获取可用连接 + if self._client_pool: + client = self._client_pool.pop() + self.logger.debug(f"♻️ Reusing client from pool | ID:{id(client)}") + return client + + # 检查是否达到最大连接数 + if len(self._active_clients) >= self.POOL_SIZE: + raise RuntimeError("Connection pool exhausted") + + # 创建新连接 + self.logger.debug("🆕 Creating new LLM client") + client = LLMClient( + host=self.config["host"], + port=self.config["port"] + ) + self._active_clients[id(client)] = client + + # 初始化连接 + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, + client.setup, + { + "model": self.config["model_name"], + "response_format": "llm.utf-8.stream", + "input": "llm.utf-8", + "enoutput": True, + "max_token_len": request.max_tokens, + "temperature": request.temperature, + "prompt": next( + (m.content for m in request.messages if m.role == "system"), + self.config.get("system_prompt", "You are a helpful assistant") + ) + } + ) + return client + + async def _release_client(self, client): + async with self._pool_lock: + # 将连接放回池中供后续使用 + self._client_pool.append(client) + self.logger.debug(f"🔙 Returned client to pool | ID:{id(client)}") + + async def inference_stream(self, query: str, request: ChatCompletionRequest): + client = await self._get_client(request) + try: + self.logger.debug(f"📡 Starting inference | ClientID:{id(client)} Query length:{len(query)}") + + loop = asyncio.get_event_loop() + sync_gen = client.inference_stream(query) + + while True: + try: + # 使用闭包捕获生成器状态 + def get_next(): + try: + return next(sync_gen) + except StopIteration: + return None # 返回哨兵值代替抛出异常 + + chunk = await loop.run_in_executor(None, get_next) + if chunk is None: # 检测到生成器结束 + break + yield chunk + except Exception as e: + self.logger.error(f"Inference error: {str(e)}") + yield f"[ERROR: {str(e)}]" + break + finally: + await self._release_client(client) + + def _truncate_history(self, messages: List[Message]) -> List[Message]: + """Truncate history to fit model context window""" + total_length = 0 + keep_messages = [] + + # Process in reverse to keep latest messages + for msg in reversed(messages): + if msg.role == "system": # Always keep system messages + keep_messages.insert(0, msg) + continue + + msg_length = len(msg.content) + if total_length + msg_length > self.MAX_CONTEXT_LENGTH: + break + total_length += msg_length + keep_messages.insert(0, msg) # Maintain original order + + return keep_messages + + async def generate(self, request: ChatCompletionRequest): + try: + truncated_messages = self._truncate_history(request.messages) + + query = "\n".join([ + f"{m.role}: {m.content}" + for m in truncated_messages + if m.role != "system" + ]) + + self.logger.debug( + f"Context truncated: Original {len(request.messages)} → Kept {len(truncated_messages)} " + f"Total length:{len(query)} chars" + ) + + if request.stream: + async def chunk_generator(): + try: + async for chunk in self.inference_stream(query, request): + yield { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": request.model, + "choices": [{ + "index": 0, + "delta": {"content": chunk}, + "finish_reason": None + }] + } + # Add normal completion marker + yield { + "choices": [{ + "delta": {}, + "finish_reason": "stop" + }] + } + except Exception as e: + self.logger.error(f"Stream generation error: {str(e)}") + yield { + "error": { + "message": f"Stream generation failed: {str(e)}", + "type": "api_error" + } + } + yield {"choices": [{"delta": {}, "finish_reason": "stop"}]} + raise + return chunk_generator() + else: + full_response = "" + async for chunk in self.inference_stream(query, request): + full_response += chunk + return { + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": request.model, + "choices": [{ + "message": { + "role": "assistant", + "content": full_response + } + }] + } + except RuntimeError as e: + self.logger.error(f"Connection error: {str(e)}") + raise HTTPException( + status_code=400, + detail=f"Model service connection failed: {str(e)}" + ) + +class ModelDispatcher: + def __init__(self): + self.backends = {} + self.load_models() + + def load_models(self): + for model_name, model_config in config.data["models"].items(): + if model_config["type"] == "openai_proxy": + self.backends[model_name] = OpenAIProxyBackend(model_config) + elif model_config["type"] == "tcp_client": + self.backends[model_name] = LlmClientBackend(model_config) + elif model_config["type"] == "llama.cpp": + self.backends[model_name] = TestBackend(model_config) + + def get_backend(self, model_name): + return self.backends.get(model_name) + +_dispatcher = ModelDispatcher() + +@app.post("/v1/chat/completions") +async def chat_completions(request: Request, body: ChatCompletionRequest): + backend = _dispatcher.get_backend(body.model) + if not backend: + raise HTTPException( + status_code=400, + detail=f"Unsupported model: {body.model}" + ) + + try: + print(f"Received request: {body.model_dump()}") + + if body.stream: + chunk_generator = await backend.generate(body) + if not chunk_generator: + raise HTTPException( + status_code=500, + detail="Failed to generate stream response" + ) + + async def format_stream(): + try: + async for chunk in chunk_generator: + if isinstance(chunk, dict): + chunk_dict = chunk + else: + chunk_dict = chunk.model_dump() + + json_chunk = json.dumps(chunk_dict, ensure_ascii=False) + print(f"Sending chunk: {json_chunk}") + yield f"data: {json_chunk}\n\n" + except Exception as e: + logger.error(f"Stream interrupted: {str(e)}") + yield f"data: {{'error': 'Stream interrupted'}}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse( + format_stream(), + media_type="text/event-stream" + ) + else: + response = await backend.generate(body) + print(f"Sending response: {response}") + return JSONResponse(content=response) + + except HTTPException as he: + raise he + except Exception as e: + logger.error(f"Processing error: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) + +logging.getLogger().handlers[0].flush() \ No newline at end of file diff --git a/llm_client.py b/llm_client.py new file mode 100644 index 0000000..58a0b39 --- /dev/null +++ b/llm_client.py @@ -0,0 +1,121 @@ +import json +import socket +import time +import uuid +from contextlib import contextmanager +from typing import Generator +import logging +import threading + +logger = logging.getLogger("llm_client") +logger.setLevel(logging.DEBUG) # 根据需要调整级别 + +class LLMClient: + def __init__(self, host: str = "localhost", port: int = 10001): + self._lock = threading.Lock() # 添加线程锁 + self.host = host + self.port = port + self.sock = None + self.work_id = None # 保存服务端返回的work_id + self._initialized = False # 新增初始化状态标记 + self._connect() # 添加连接方法 + + def __enter__(self): + self.connect() + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.close() + + def _connect(self): + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + self.sock.connect((self.host, self.port)) + except ConnectionRefusedError as e: + raise RuntimeError(f"无法连接到 {self.host}:{self.port}") from e + + def close(self): + if self.sock: + self.sock.close() + self.sock = None + + def _send_request(self, action: str, data: dict) -> str: + """通用请求发送方法""" + request_id = str(uuid.uuid4()) + payload = { + "request_id": request_id, + "work_id": self.work_id or "llm", + "action": action, + "object": "llm.setup" if action == "setup" else "llm.utf-8", + "data": data + } + + logger.debug( + f"Sending request: [ID:{request_id}] " + f"Action:{action} WorkID:{payload['work_id']}\n" + f"Data: {str(data)[:100]}..." + ) + + self.sock.sendall(json.dumps(payload, ensure_ascii=False).encode('utf-8')) + return request_id + + def setup(self, model_config: dict) -> dict: + if not self.sock: + self._connect() + request_id = self._send_request("setup", model_config) + return self._wait_response(request_id) + + def inference_stream(self, query: str) -> Generator[str, None, None]: + request_id = self._send_request("inference", query) + + while True: + response = json.loads(self.sock.recv(4096).decode()) + if response["request_id"] != request_id: + continue + + yield response["data"]["delta"] + if response["data"].get("finish", False): + self.work_id = response["work_id"] + break + + def exit(self) -> dict: + request_id = self._send_request("exit", {}) + result = self._wait_response(request_id) + self._initialized = False + return result + + def _wait_response(self, request_id: str) -> dict: + start_time = time.time() + while time.time() - start_time < 10: + response = json.loads(self.sock.recv(4096).decode()) + if response["request_id"] == request_id: + if response["error"]["code"] != 0: + raise RuntimeError(f"Server error: {response['error']['message']}") + self.work_id = response["work_id"] + return response + raise TimeoutError("No response from server") + + def connect(self): + """显式连接方法""" + with self._lock: + if not self.sock: + self._connect() + +# 使用示例 +if __name__ == "__main__": + with LLMClient(host='192.168.20.183') as client: + setup_response = client.setup({ + "model": "deepseek-r1-1.5B-ax630c", + "response_format": "llm.utf-8.stream", + "input": "llm.utf-8", + "enoutput": True, + "max_token_len": 256, + "prompt": "You are a helpful assistant" + }) + print("Setup response:", setup_response) + + for chunk in client.inference_stream("What's your name?"): + print("Received chunk:", chunk) + + exit_response = client.exit() + print("Exit response:", exit_response)