[perf] Optimize token statistics method

This commit is contained in:
LittleMouse
2025-04-09 15:56:50 +08:00
parent 82997ac42b
commit 423fa93a70
3 changed files with 100275 additions and 5 deletions
+3 -2
View File
@@ -38,6 +38,9 @@ class Config:
config_path = os.path.join(current_dir, "config", "config.yaml")
with open(config_path) as f:
self.data = yaml.safe_load(f)
tiktoken_cache_dir = os.path.join(current_dir, "cache")
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
config = Config()
@@ -75,8 +78,6 @@ class ModelDispatcher:
await old_instance.close()
self.backends[model_name] = LlmClientBackend(model_config)
self.llm_models.append(model_name)
elif model_config["type"] == "llama.cpp":
self.backends[model_name] = TestBackend(model_config)
elif model_config["type"] == "vision_model":
self.backends[model_name] = VisionModelBackend(model_config)
elif model_config["type"] == "tts":
+16 -3
View File
@@ -12,6 +12,7 @@ import logging
from fastapi import HTTPException
from typing import Union, List
from services.memory_check import MemoryChecker
import tiktoken
class LlmClientBackend(BaseModelBackend):
def __init__(self, model_config):
@@ -28,6 +29,7 @@ class LlmClientBackend(BaseModelBackend):
host=self.config["host"],
port=self.config["port"]
)
self.tokenizer = tiktoken.get_encoding("cl100k_base")
async def _parse_content(self, content: Union[str, List[ContentItem]], base64_images: list) -> str:
text_parts = []
@@ -170,6 +172,10 @@ class LlmClientBackend(BaseModelBackend):
await self._release_client(client)
self.logger.debug(f"Inference stopped | ClientID:{id(client)}")
def _count_tokens(self, text: str) -> int:
"""Count the number of tokens in a given text."""
return len(self.tokenizer.encode(text))
def _truncate_history(self, messages: List[Message]) -> List[Message]:
"""Truncate history to fit model context window"""
total_length = 0
@@ -180,8 +186,16 @@ class LlmClientBackend(BaseModelBackend):
if msg.role == "system": # Always keep system messages
keep_messages.insert(0, msg)
continue
msg_length = len(msg.content)
if isinstance(msg.content, list):
msg_length = 0
for item in msg.content:
if item.type == "text":
msg_length += self._count_tokens(item.text)
total_length += msg_length
keep_messages.insert(0, msg)
break
else:
msg_length = self._count_tokens(msg.content)
if total_length + msg_length > self.MAX_CONTEXT_LENGTH:
break
total_length += msg_length
@@ -225,7 +239,6 @@ class LlmClientBackend(BaseModelBackend):
final_query.append(system_prompt.strip())
if base64_images:
pass
# final_query.append("\n".join([f"[IMAGE:{img[:20]}...]" for img in base64_images]))
final_query.append("\n".join(query_lines))
query = "\n\n".join(filter(None, final_query))
File diff suppressed because it is too large Load Diff