[perf]: Optimize parameter passing method

This commit is contained in:
LittleMouse
2025-02-14 16:37:25 +08:00
parent db0d24221f
commit 85cc20d391
2 changed files with 22 additions and 20 deletions
+15 -13
View File
@@ -180,19 +180,21 @@ class LlmClientBackend(BaseModelBackend):
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None,
client.setup,
{
"model": self.config["model_name"],
"response_format": "llm.utf-8.stream",
"input": "llm.utf-8",
"enoutput": True,
"max_token_len": request.max_tokens,
"temperature": request.temperature,
"prompt": next(
(m.content for m in request.messages if m.role == "system"),
self.config.get("system_prompt", "You are a helpful assistant")
)
}
lambda: client.setup(
self.config["object"],
{
"model": self.config["model_name"],
"response_format": self.config["response_format"],
"input": self.config["input"],
"enoutput": True,
"max_token_len": request.max_tokens,
"temperature": request.temperature,
"prompt": next(
(m.content for m in request.messages if m.role == "system"),
self.config.get("system_prompt", "You are a helpful assistant")
)
}
)
)
return client
+7 -7
View File
@@ -39,14 +39,14 @@ class LLMClient:
self.sock.close()
self.sock = None
def _send_request(self, action: str, data: dict) -> str:
def _send_request(self, action: str, object: str, data: dict) -> str:
"""通用请求发送方法"""
request_id = str(uuid.uuid4())
payload = {
"request_id": request_id,
"work_id": self.work_id or "llm",
"action": action,
"object": "llm.setup" if action == "setup" else "llm.utf-8",
"object": object,
"data": data
}
@@ -59,14 +59,14 @@ class LLMClient:
self.sock.sendall(json.dumps(payload, ensure_ascii=False).encode('utf-8'))
return request_id
def setup(self, model_config: dict) -> dict:
def setup(self, object: str, model_config: dict) -> dict:
if not self.sock:
self._connect()
request_id = self._send_request("setup", model_config)
request_id = self._send_request("setup", object, model_config)
return self._wait_response(request_id)
def inference_stream(self, query: str) -> Generator[str, None, None]:
request_id = self._send_request("inference", query)
request_id = self._send_request("inference", "llm.utf-8", query)
while True:
response = json.loads(self.sock.recv(4096).decode())
@@ -79,7 +79,7 @@ class LLMClient:
break
def exit(self) -> dict:
request_id = self._send_request("exit", {})
request_id = self._send_request("exit", "llm.utf-8", {})
result = self._wait_response(request_id)
self._initialized = False
return result
@@ -104,7 +104,7 @@ class LLMClient:
# 使用示例
if __name__ == "__main__":
with LLMClient(host='192.168.20.183') as client:
setup_response = client.setup({
setup_response = client.setup("llm.setup", {
"model": "deepseek-r1-1.5B-ax630c",
"response_format": "llm.utf-8.stream",
"input": "llm.utf-8",