Files
ModuleLLM-OpenAI-Plugin/api_server.py
T

331 lines
11 KiB
Python
Raw Normal View History

2025-02-13 18:10:46 +08:00
import os
import uuid
import yaml
2025-02-20 18:43:45 +08:00
from fastapi import FastAPI, Request, HTTPException, File, Form, UploadFile
2025-02-13 18:10:46 +08:00
from fastapi.responses import JSONResponse, StreamingResponse
import logging
from slowapi import Limiter
from slowapi.util import get_remote_address
import time
import json
import asyncio
2025-02-20 18:43:45 +08:00
from backend import (
TestBackend,
OpenAIProxyBackend,
LlmClientBackend,
VisionModelBackend,
ASRClientBackend,
TtsClientBackend,
ChatCompletionRequest,
CompletionRequest,
Message,
)
2025-02-13 18:10:46 +08:00
from services.model_list import GetModelList
2025-02-13 18:10:46 +08:00
# Root logging: every record from DEBUG upward goes to one StreamHandler
# (stderr by default).
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
    handlers=[logging.StreamHandler()],
)
# Named logger used by the request handlers in this module.
logger = logging.getLogger("api")

# The ASGI application that all routes below are registered on.
app = FastAPI(title="OpenAI Compatible API Server")
# Rate limiter keyed on the client address.
# NOTE(review): in this file it is never attached to `app` (no
# `app.state.limiter`, no `@limiter.limit` decorators), so it throttles nothing.
limiter = Limiter(key_func=get_remote_address)
class Config:
    """Server configuration parsed from a YAML file.

    The parsed document is stored unchanged in ``self.data``; this module reads
    ``data["models"]`` and ``data["server"]`` from it.
    """

    def __init__(self, path="config/config.yaml"):
        """Load and parse the YAML file at *path*.

        Args:
            path: Location of the config file. Relative paths resolve against
                the current working directory (unchanged default location).

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
        """
        # Decode explicitly as UTF-8 so non-ASCII values load identically no
        # matter what the host locale's default encoding is.
        with open(path, encoding="utf-8") as f:
            self.data = yaml.safe_load(f)


# Module-wide configuration; `initialize` rebinds this after refreshing models.
config = Config()
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Optionally require a bearer token on every ``/v1`` route.

    Authentication is opt-in. When the ``API_KEY`` environment variable is
    unset or empty, every request passes through untouched, which is what this
    server previously did. When it is set, ``/v1`` requests must carry
    ``Authorization: Bearer <API_KEY>``; otherwise they receive a 401.

    Args:
        request: The incoming HTTP request.
        call_next: Callback that hands the request to the rest of the app.

    Returns:
        The downstream response, or a 401 ``JSONResponse`` on a key mismatch.
    """
    import hmac

    expected_key = os.getenv("API_KEY")
    if expected_key and request.url.path.startswith("/v1"):
        auth_header = request.headers.get("Authorization", "")
        prefix = "Bearer "
        # Strip only a leading "Bearer " (the old str.replace removed it
        # anywhere); a bare key without the prefix is still accepted.
        if auth_header.startswith(prefix):
            supplied_key = auth_header[len(prefix):]
        else:
            supplied_key = auth_header
        # Constant-time comparison so response timing does not leak how much
        # of the key matched; bytes so non-ASCII input cannot raise TypeError.
        if not hmac.compare_digest(supplied_key.encode("utf-8"), expected_key.encode("utf-8")):
            return JSONResponse(
                status_code=401,
                content={"error": "Invalid authentication credentials"},
            )
    return await call_next(request)
class ModelDispatcher:
    """Create backends on demand and cache one instance per model name.

    llm/vlm backends share a bounded pool. Before a new one is built, the
    earliest-created llm/vlm backends are closed and dropped until fewer than
    the new model's ``pool_size`` remain.
    """

    def __init__(self):
        # Maps model name -> backend instance, for every backend type.
        self.backends = {}
        # Names of cached llm/vlm backends in creation order; index 0 is
        # evicted first.
        self.llm_models = []
        # Serializes get_backend so concurrent requests never build or evict
        # backends at the same time.
        self.lock = asyncio.Lock()

    async def _evict_llms(self, pool_size):
        """Close and forget the oldest llm/vlm backends until fewer than *pool_size* remain.

        Stops once the pool is empty, so a ``pool_size`` of 0 or less no longer
        raises ``IndexError`` from ``pop(0)`` on an empty list.
        """
        while self.llm_models and len(self.llm_models) >= pool_size:
            oldest_model = self.llm_models.pop(0)
            old_instance = self.backends.pop(oldest_model, None)
            if old_instance:
                await old_instance.close()

    async def get_backend(self, model_name):
        """Return the backend for *model_name*, creating it on first use.

        Args:
            model_name: Key looked up in ``config.data["models"]``.

        Returns:
            The cached or newly created backend, or ``None`` when the model is
            not configured or its ``type`` is not recognised.

        Raises:
            KeyError: If the model's config lacks ``type``, or lacks
                ``pool_size`` for an llm/vlm model.
        """
        async with self.lock:
            if model_name not in self.backends:
                model_config = config.data["models"].get(model_name)
                if model_config is None:
                    return None
                model_type = model_config["type"]
                if model_type == "openai_proxy":
                    self.backends[model_name] = OpenAIProxyBackend(model_config)
                elif model_type in ("llm", "vlm"):
                    # Make room in the shared llm/vlm pool before building.
                    await self._evict_llms(model_config["pool_size"])
                    self.backends[model_name] = LlmClientBackend(model_config)
                    self.llm_models.append(model_name)
                elif model_type == "llama.cpp":
                    self.backends[model_name] = TestBackend(model_config)
                elif model_type == "vision_model":
                    self.backends[model_name] = VisionModelBackend(model_config)
                elif model_type == "tts":
                    self.backends[model_name] = TtsClientBackend(model_config)
                elif model_type == "asr":
                    self.backends[model_name] = ASRClientBackend(model_config)
                else:
                    return None
            return self.backends.get(model_name)
2025-02-13 18:10:46 +08:00
async def initialize():
    """Query the model list, reload the config, and build the dispatcher.

    Returns:
        A new ``ModelDispatcher``, which reads from the reloaded ``config``.
    """
    global config
    server_settings = config.data["server"]
    lister = GetModelList(host=server_settings["host"], port=server_settings["port"])
    await lister.get_model_list(required_mem=0)
    # Re-read config/config.yaml only after the model-list query. Presumably
    # get_model_list updates that file with the currently available models;
    # TODO confirm against services.model_list.
    config = Config()
    return ModelDispatcher()


# Runs once at import time, before any request is handled.
# NOTE(review): ModelDispatcher (and its asyncio.Lock) is constructed inside
# this short-lived asyncio.run() loop, while the server later runs on another
# loop. Python 3.10+ binds asyncio locks lazily, so that is fine; older
# interpreters bind at construction, so confirm the deployed version.
_dispatcher = asyncio.run(initialize())
2025-02-13 18:10:46 +08:00
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, body: ChatCompletionRequest):
    """OpenAI-compatible ``POST /v1/chat/completions``.

    Resolves the backend for ``body.model``. If ``body.stream`` is set, the
    result is streamed as Server-Sent Events (``data: {json}`` frames followed
    by ``data: [DONE]``); otherwise the backend's response is returned as JSON.

    Raises:
        HTTPException: 400 for an unknown model, 500 when the backend fails
            before the response has started.
    """
    backend = await _dispatcher.get_backend(body.model)
    if not backend:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported model: {body.model}"
        )
    try:
        logger.debug("Received request: %s", body.model_dump())
        if body.stream:
            chunk_generator = await backend.generate(body)
            if not chunk_generator:
                raise HTTPException(
                    status_code=500,
                    detail="Failed to generate stream response"
                )

            async def format_stream():
                try:
                    async for chunk in chunk_generator:
                        # Backends may yield plain dicts or pydantic models.
                        chunk_dict = chunk if isinstance(chunk, dict) else chunk.model_dump()
                        json_chunk = json.dumps(chunk_dict, ensure_ascii=False)
                        logger.debug("Sending chunk: %s", json_chunk)
                        yield f"data: {json_chunk}\n\n"
                    # OpenAI's streaming protocol ends with a literal [DONE]
                    # event, which clients use to detect a complete stream
                    # (/v1/completions already sends it).
                    yield "data: [DONE]\n\n"
                except asyncio.CancelledError:
                    logger.warning("Client disconnected early, terminating inference...")
                    # Stop the backend's in-flight work for this disconnect.
                    if isinstance(backend, LlmClientBackend):
                        for task in backend._active_tasks:
                            task.cancel()
                    raise
                finally:
                    logger.debug("Stream connection closed")

            return StreamingResponse(
                format_stream(),
                media_type="text/event-stream"
            )

        response = await backend.generate(body)
        logger.debug("Sending response: %s", response)
        return JSONResponse(content=response)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Processing error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
2025-02-13 18:32:46 +08:00
@app.post("/v1/completions")
async def create_completion(request: Request, body: CompletionRequest):
    """OpenAI-compatible legacy ``POST /v1/completions``.

    The prompt is wrapped as one user message and sent through the chat
    backend. The chat-shaped result is converted back to the
    ``text_completion`` format, and streamed as SSE when ``body.stream`` is set.

    Raises:
        HTTPException: 400 for an unknown model (or any HTTPException raised by
            the backend, passed through unchanged), 500 for other failures
            before the response starts.
    """
    chat_request = ChatCompletionRequest(
        model=body.model,
        messages=[Message(role="user", content=body.prompt)],
        temperature=body.temperature,
        max_tokens=body.max_tokens,
        top_p=body.top_p,
        stream=body.stream
    )
    backend = await _dispatcher.get_backend(chat_request.model)
    if not backend:
        raise HTTPException(status_code=400, detail=f"Unsupported model: {chat_request.model}")
    try:
        if body.stream:
            chunk_generator = await backend.generate(chat_request)
            # Generated once per stream, so every chunk of a single completion
            # shares the same id/timestamp when the backend omits them.
            fallback_id = f"cmpl-{uuid.uuid4()}"
            fallback_created = int(time.time())

            async def convert_stream():
                async for chunk in chunk_generator:
                    # Backends may yield plain dicts or pydantic models
                    # (chat_completions accepts both the same way).
                    chunk_dict = chunk if isinstance(chunk, dict) else chunk.model_dump()
                    choices = chunk_dict.get("choices") or []
                    if not choices:
                        # Nothing to translate (e.g. a usage-only chunk, if
                        # the backend emits one).
                        continue
                    choice = choices[0]
                    delta = choice.get("delta") or {}
                    # Convert chat delta format to a text_completion chunk.
                    completion_chunk = {
                        "id": chunk_dict.get("id") or fallback_id,
                        "object": "text_completion.chunk",
                        "created": chunk_dict.get("created") or fallback_created,
                        "model": chat_request.model,
                        "choices": [{
                            # content may be present but None (e.g. a
                            # role-only first delta); emit "" instead of null.
                            "text": delta.get("content") or "",
                            "index": 0,
                            "logprobs": None,
                            "finish_reason": choice.get("finish_reason")
                        }]
                    }
                    yield f"data: {json.dumps(completion_chunk, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                convert_stream(),
                media_type="text/event-stream"
            )

        chat_response = await backend.generate(chat_request)
        first_choice = chat_response["choices"][0]
        return JSONResponse({
            "id": f"cmpl-{uuid.uuid4()}",
            "object": "text_completion",
            "created": int(time.time()),
            "model": chat_request.model,
            "choices": [{
                "text": first_choice["message"]["content"],
                "index": 0,
                "logprobs": None,
                # Propagate the backend's reason (e.g. "length") when present.
                "finish_reason": first_choice.get("finish_reason") or "stop"
            }],
            "usage": chat_response.get("usage", {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            })
        })
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Completion error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
2025-02-20 18:43:45 +08:00
@app.post("/v1/audio/speech")
async def create_speech(request: Request):
    """OpenAI-compatible ``POST /v1/audio/speech`` (text to speech).

    Reads a JSON body with ``model``, ``input``, optional ``voice`` (default
    ``"alloy"``) and ``response_format`` (default ``"mp3"``). The audio returned
    by the backend's ``generate_speech`` is streamed back as an attachment
    named ``speech.<format>``.

    Raises:
        HTTPException: 400 for an unknown model, 500 for any other failure.
    """
    try:
        request_data = await request.json()
        backend = await _dispatcher.get_backend(request_data.get("model"))
        if not backend:
            raise HTTPException(status_code=400, detail="Unsupported model")
        response_format = request_data.get("response_format", "mp3")
        audio_stream = backend.generate_speech(
            input_text=request_data.get("input"),
            voice=request_data.get("voice", "alloy"),
            format=response_format
        )
        return StreamingResponse(
            audio_stream,
            media_type=f"audio/{response_format}",
            headers={"Content-Disposition": f'attachment; filename="speech.{response_format}"'}
        )
    except HTTPException:
        # Re-raise as-is; otherwise the 400 above is caught below and becomes a 500.
        raise
    except Exception as e:
        logger.exception("Speech generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/v1/audio/transcriptions")
async def create_transcription(
    file: UploadFile = File(...),
    model: str = Form(...),
    language: str = Form(None),
    prompt: str = Form(""),
    response_format: str = Form("json")
):
    """OpenAI-compatible ``POST /v1/audio/transcriptions`` (speech to text).

    Reads the uploaded audio and returns the backend's ``create_transcription``
    result as JSON with ``text``, ``task``, ``language`` and ``duration``.

    NOTE(review): ``response_format`` is accepted but not used; the reply is
    always JSON.

    Raises:
        HTTPException: 400 for an unknown model, 500 for any other failure.
    """
    try:
        backend = await _dispatcher.get_backend(model)
        if not backend:
            raise HTTPException(status_code=400, detail="Unsupported model")
        audio_data = await file.read()
        transcription = await backend.create_transcription(
            audio_data,
            language=language,
            prompt=prompt
        )
        return JSONResponse(content={
            "text": transcription,
            "task": "transcribe",
            "language": language,
            # Duration is not measured; always reported as 0.
            "duration": 0
        })
    except HTTPException:
        # Re-raise as-is; otherwise the 400 above is caught below and becomes a 500.
        raise
    except Exception as e:
        logger.exception("Transcription error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/v1/audio/translations")
async def create_translation(
    file: UploadFile = File(...),
    model: str = Form(...),
    prompt: str = Form(""),
    response_format: str = Form("json")
):
    """OpenAI-compatible ``POST /v1/audio/translations``.

    Reads the uploaded audio and returns the backend's ``create_translation``
    result as JSON with ``text``, ``task`` and ``duration``.

    NOTE(review): ``response_format`` is accepted but not used; the reply is
    always JSON.

    Raises:
        HTTPException: 400 for an unknown model, 500 for any other failure.
    """
    try:
        backend = await _dispatcher.get_backend(model)
        if not backend:
            raise HTTPException(status_code=400, detail="Unsupported model")
        audio_data = await file.read()
        translation = await backend.create_translation(
            audio_data,
            prompt=prompt
        )
        return JSONResponse(content={
            "text": translation,
            "task": "translate",
            # Duration is not measured; always reported as 0.
            "duration": 0
        })
    except HTTPException:
        # Re-raise as-is; otherwise the 400 above is caught below and becomes a 500.
        raise
    except Exception as e:
        logger.exception("Translation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible ``GET /v1/models``: list every model in the config.

    The optional per-model keys ``created``, ``owner`` and ``root`` fill the
    matching fields. Missing keys default to ``0``, ``"user"`` and ``""``.

    Returns:
        ``{"data": [...], "object": "list"}``.
    """
    models_info = []
    for model_name, model_config in config.data["models"].items():
        # A YAML model entry with no body parses as None; treat it as empty
        # instead of crashing on .get().
        model_config = model_config or {}
        models_info.append({
            "id": model_name,
            "object": "model",
            "created": model_config.get("created", 0),
            "owned_by": model_config.get("owner", "user"),
            "permission": [],
            "root": model_config.get("root", "")
        })
    return {
        "data": models_info,
        "object": "list"
    }
2025-02-13 18:10:46 +08:00
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces at port 8000; returns only once the server stops.
    uvicorn.run(app, host="0.0.0.0", port=8000)
    # Flush whatever the first root-logger handler still has buffered.
    root_handlers = logging.getLogger().handlers
    root_handlers[0].flush()