瀏覽代碼

Fix: The max tokens defined by the tenant are not used (#4297) (#2817) (#5066)

### What problem does this PR solve?

Fix: The max tokens defined by the tenant are not used (#4297) (#2817)


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
tags/v0.17.0
zhxlp 8 月之前
父節點
當前提交
00c7ddbc9b
沒有連結到貢獻者的電子郵件帳戶。
共有 2 個檔案被更改,包括 15 行新增21 行删除
  1. 6
    15
      api/db/services/dialog_service.py
  2. 9
    6
      api/db/services/llm_service.py

+ 6
- 15
api/db/services/dialog_service.py 查看文件

@@ -29,7 +29,7 @@ from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api import settings
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from rag.app.resume import forbidden_select_fields4resume
@@ -172,21 +172,12 @@ def chat(dialog, messages, stream=True, **kwargs):

chat_start_ts = timer()

# Get llm model name and model provider name
llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)

# Get llm model instance by model and provide name
llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)

if not llm:
# Model name is provided by tenant, but not system built-in
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 8192
if llm_id2llm_type(dialog.llm_id) == "image2text":
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
max_tokens = llm[0].max_tokens
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

max_tokens = llm_model_config.get("max_tokens", 8192)

check_llm_ts = timer()


+ 9
- 6
api/db/services/llm_service.py 查看文件

@@ -86,8 +86,7 @@ class TenantLLMService(CommonService):

@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type,
llm_name=None, lang="Chinese"):
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
@@ -124,7 +123,13 @@ class TenantLLMService(CommonService):
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
raise LookupError("Model({}) not authorized".format(mdlnm))
return model_config

@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type,
llm_name=None, lang="Chinese"):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return
@@ -228,10 +233,8 @@ class LLMBundle(object):
tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find model for {}/{}/{}".format(
tenant_id, llm_type, llm_name)
self.max_length = 8192
for lm in LLMService.query(llm_name=llm_name):
self.max_length = lm.max_tokens
break
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.max_length = model_config.get("max_tokens", 8192)

def encode(self, texts: list):
embeddings, used_tokens = self.mdl.encode(texts)

Loading…
取消
儲存