You've already forked ModuleLLM-OpenAI-Plugin
mirror of
https://github.com/m5stack/ModuleLLM-OpenAI-Plugin.git
synced 2026-05-20 11:37:26 -07:00
[perf] Optimize token statistics method
This commit is contained in:
+3
-2
@@ -38,6 +38,9 @@ class Config:
|
||||
config_path = os.path.join(current_dir, "config", "config.yaml")
|
||||
with open(config_path) as f:
|
||||
self.data = yaml.safe_load(f)
|
||||
|
||||
tiktoken_cache_dir = os.path.join(current_dir, "cache")
|
||||
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
||||
|
||||
config = Config()
|
||||
|
||||
@@ -75,8 +78,6 @@ class ModelDispatcher:
|
||||
await old_instance.close()
|
||||
self.backends[model_name] = LlmClientBackend(model_config)
|
||||
self.llm_models.append(model_name)
|
||||
elif model_config["type"] == "llama.cpp":
|
||||
self.backends[model_name] = TestBackend(model_config)
|
||||
elif model_config["type"] == "vision_model":
|
||||
self.backends[model_name] = VisionModelBackend(model_config)
|
||||
elif model_config["type"] == "tts":
|
||||
|
||||
@@ -12,6 +12,7 @@ import logging
|
||||
from fastapi import HTTPException
|
||||
from typing import Union, List
|
||||
from services.memory_check import MemoryChecker
|
||||
import tiktoken
|
||||
|
||||
class LlmClientBackend(BaseModelBackend):
|
||||
def __init__(self, model_config):
|
||||
@@ -28,6 +29,7 @@ class LlmClientBackend(BaseModelBackend):
|
||||
host=self.config["host"],
|
||||
port=self.config["port"]
|
||||
)
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
async def _parse_content(self, content: Union[str, List[ContentItem]], base64_images: list) -> str:
|
||||
text_parts = []
|
||||
@@ -170,6 +172,10 @@ class LlmClientBackend(BaseModelBackend):
|
||||
await self._release_client(client)
|
||||
self.logger.debug(f"Inference stopped | ClientID:{id(client)}")
|
||||
|
||||
def _count_tokens(self, text: str) -> int:
|
||||
"""Count the number of tokens in a given text."""
|
||||
return len(self.tokenizer.encode(text))
|
||||
|
||||
def _truncate_history(self, messages: List[Message]) -> List[Message]:
|
||||
"""Truncate history to fit model context window"""
|
||||
total_length = 0
|
||||
@@ -180,8 +186,16 @@ class LlmClientBackend(BaseModelBackend):
|
||||
if msg.role == "system": # Always keep system messages
|
||||
keep_messages.insert(0, msg)
|
||||
continue
|
||||
|
||||
msg_length = len(msg.content)
|
||||
if isinstance(msg.content, list):
|
||||
msg_length = 0
|
||||
for item in msg.content:
|
||||
if item.type == "text":
|
||||
msg_length += self._count_tokens(item.text)
|
||||
total_length += msg_length
|
||||
keep_messages.insert(0, msg)
|
||||
break
|
||||
else:
|
||||
msg_length = self._count_tokens(msg.content)
|
||||
if total_length + msg_length > self.MAX_CONTEXT_LENGTH:
|
||||
break
|
||||
total_length += msg_length
|
||||
@@ -225,7 +239,6 @@ class LlmClientBackend(BaseModelBackend):
|
||||
final_query.append(system_prompt.strip())
|
||||
if base64_images:
|
||||
pass
|
||||
# final_query.append("\n".join([f"[IMAGE:{img[:20]}...]" for img in base64_images]))
|
||||
final_query.append("\n".join(query_lines))
|
||||
|
||||
query = "\n\n".join(filter(None, final_query))
|
||||
|
||||
+100256
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user