|
|
|
@@ -20,7 +20,7 @@ from flask_login import login_required |
|
|
|
from api.db.services.dialog_service import DialogService, ConversationService
|
|
|
|
from api.db import LLMType
|
|
|
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
|
from api.db.services.llm_service import LLMService, LLMBundle
|
|
|
|
from api.db.services.llm_service import LLMService, LLMBundle, TenantLLMService
|
|
|
|
from api.settings import access_logger, stat_logger, retrievaler, chat_logger
|
|
|
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
|
|
|
from api.utils import get_uuid
|
|
|
|
@@ -184,8 +184,11 @@ def chat(dialog, messages, **kwargs): |
|
|
|
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
|
|
|
llm = LLMService.query(llm_name=dialog.llm_id)
|
|
|
|
if not llm:
|
|
|
|
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
|
|
|
llm = llm[0]
|
|
|
|
llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
|
|
|
|
if not llm:
|
|
|
|
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
|
|
|
max_tokens = 1024
|
|
|
|
else: max_tokens = llm[0].max_tokens
|
|
|
|
questions = [m["content"] for m in messages if m["role"] == "user"]
|
|
|
|
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
|
|
|
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
|
|
|
@@ -227,11 +230,11 @@ def chat(dialog, messages, **kwargs): |
|
|
|
gen_conf = dialog.llm_setting
|
|
|
|
msg = [{"role": m["role"], "content": m["content"]}
|
|
|
|
for m in messages if m["role"] != "system"]
|
|
|
|
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
|
|
|
|
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
|
|
|
|
if "max_tokens" in gen_conf:
|
|
|
|
gen_conf["max_tokens"] = min(
|
|
|
|
gen_conf["max_tokens"],
|
|
|
|
llm.max_tokens - used_token_count)
|
|
|
|
max_tokens - used_token_count)
|
|
|
|
answer = chat_mdl.chat(
|
|
|
|
prompt_config["system"].format(
|
|
|
|
**kwargs), msg, gen_conf)
|