From 85cc20d391cd2c03a922d19774cdcf7c3be23c6c Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 14 Feb 2025 16:37:25 +0800 Subject: [PATCH] [perf]: Optimize parameter passing method --- api_server.py | 28 +++++++++++++++------------- llm_client.py | 14 +++++++------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/api_server.py b/api_server.py index 47b91d3..2cb1d5f 100644 --- a/api_server.py +++ b/api_server.py @@ -180,19 +180,21 @@ class LlmClientBackend(BaseModelBackend): loop = asyncio.get_event_loop() await loop.run_in_executor( None, - client.setup, - { - "model": self.config["model_name"], - "response_format": "llm.utf-8.stream", - "input": "llm.utf-8", - "enoutput": True, - "max_token_len": request.max_tokens, - "temperature": request.temperature, - "prompt": next( - (m.content for m in request.messages if m.role == "system"), - self.config.get("system_prompt", "You are a helpful assistant") - ) - } + lambda: client.setup( + self.config["object"], + { + "model": self.config["model_name"], + "response_format": self.config["response_format"], + "input": self.config["input"], + "enoutput": True, + "max_token_len": request.max_tokens, + "temperature": request.temperature, + "prompt": next( + (m.content for m in request.messages if m.role == "system"), + self.config.get("system_prompt", "You are a helpful assistant") + ) + } + ) ) return client diff --git a/llm_client.py b/llm_client.py index 58a0b39..2b0abe2 100644 --- a/llm_client.py +++ b/llm_client.py @@ -39,14 +39,14 @@ class LLMClient: self.sock.close() self.sock = None - def _send_request(self, action: str, data: dict) -> str: + def _send_request(self, action: str, object: str, data: dict) -> str: """通用请求发送方法""" request_id = str(uuid.uuid4()) payload = { "request_id": request_id, "work_id": self.work_id or "llm", "action": action, - "object": "llm.setup" if action == "setup" else "llm.utf-8", + "object": object, "data": data } @@ -59,14 +59,14 @@ class LLMClient: self.sock.sendall(json.dumps(payload, ensure_ascii=False).encode('utf-8')) return request_id - def setup(self, model_config: dict) -> dict: + def setup(self, object: str, model_config: dict) -> dict: if not self.sock: self._connect() - request_id = self._send_request("setup", model_config) + request_id = self._send_request("setup", object, model_config) return self._wait_response(request_id) def inference_stream(self, query: str) -> Generator[str, None, None]: - request_id = self._send_request("inference", query) + request_id = self._send_request("inference", "llm.utf-8", query) while True: response = json.loads(self.sock.recv(4096).decode()) @@ -79,7 +79,7 @@ class LLMClient: break def exit(self) -> dict: - request_id = self._send_request("exit", {}) + request_id = self._send_request("exit", "llm.utf-8", {}) result = self._wait_response(request_id) self._initialized = False return result @@ -104,7 +104,7 @@ class LLMClient: # 使用示例 if __name__ == "__main__": with LLMClient(host='192.168.20.183') as client: - setup_response = client.setup({ + setup_response = client.setup("llm.setup", { "model": "deepseek-r1-1.5B-ax630c", "response_format": "llm.utf-8.stream", "input": "llm.utf-8",