Browse Source

Fix: issue for tavily only in a assistant. (#8076)

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.19.1
Kevin Hu 4 months ago
parent
commit
91804f28f1
No account linked to committer's email address
1 changed files with 85 additions and 101 deletions
  1. 85
    101
      api/db/services/dialog_service.py

+ 85
- 101
api/db/services/dialog_service.py View File

@@ -127,6 +127,31 @@ def chat_solo(dialog, messages, stream=True):
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}


def get_models(dialog):
embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
if len(embedding_list) > 1:
raise Exception("**ERROR**: Knowledge bases use different embedding models.")

if embedding_list:
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_list[0])

if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

if dialog.prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl


BAD_CITATION_PATTERNS = [
re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12)
re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12]
@@ -134,10 +159,38 @@ BAD_CITATION_PATTERNS = [
re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
]

def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
max_index = len(kbinfos["chunks"])

def safe_add(i):
if 0 <= i < max_index:
idx.add(i)
return True
return False

def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
nonlocal answer

def replacement(match):
try:
i = int(match.group(group_index))
if safe_add(i):
return f"[{repl(i)}]"
except Exception:
pass
return match.group(0)

answer = re.sub(pattern, replacement, answer, flags=flags)

for pattern in BAD_CITATION_PATTERNS:
find_and_replace(pattern)

return answer, idx


def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
if not dialog.kb_ids:
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
for ans in chat_solo(dialog, messages, stream):
yield ans
return
@@ -162,45 +215,19 @@ def chat(dialog, messages, stream=True, **kwargs):
langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}")

check_langfuse_tracer_ts = timer()

kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
if len(embedding_list) != 1:
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

embedding_model_name = embedding_list[0]
kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
if toolcall_session and tools:
chat_mdl.bind_tools(toolcall_session, tools)
bind_models_ts = timer()

retriever = settings.retrievaler

questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
if "doc_ids" in messages[-1]:
attachments = messages[-1]["doc_ids"]

create_retriever_ts = timer()

embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embedding_model_name)

bind_embedding_ts = timer()

if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
if toolcall_session and tools:
chat_mdl.bind_tools(toolcall_session, tools)

bind_llm_ts = timer()

prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
@@ -225,26 +252,18 @@ def chat(dialog, messages, stream=True, **kwargs):
if prompt_config.get("cross_languages"):
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

refine_question_ts = timer()
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])

rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
refine_question_ts = timer()

bind_reranker_ts = timer()
generate_keyword_ts = bind_reranker_ts
thought = ""
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
knowledges = []
else:
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
generate_keyword_ts = timer()

tenant_ids = list(set([kb.tenant_id for kb in kbs]))

knowledges = []
if prompt_config.get("reasoning", False):
reasoner = DeepResearcher(
@@ -260,21 +279,22 @@ def chat(dialog, messages, stream=True, **kwargs):
elif stream:
yield think
else:
kbinfos = retriever.retrieval(
" ".join(questions),
embd_mdl,
tenant_ids,
dialog.kb_ids,
1,
dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs),
)
if embd_mdl:
kbinfos = retriever.retrieval(
" ".join(questions),
embd_mdl,
tenant_ids,
dialog.kb_ids,
1,
dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=attachments,
top=dialog.top_k,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(" ".join(questions), kbs),
)
if prompt_config.get("tavily_api_key"):
tav = Tavily(prompt_config["tavily_api_key"])
tav_res = tav.retrieve_chunks(" ".join(questions))
@@ -310,36 +330,8 @@ def chat(dialog, messages, stream=True, **kwargs):
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
max_index = len(kbinfos["chunks"])

def safe_add(i):
if 0 <= i < max_index:
idx.add(i)
return True
return False

def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
nonlocal answer

def replacement(match):
try:
i = int(match.group(group_index))
if safe_add(i):
return f"[{repl(i)}]"
except Exception:
pass
return match.group(0)

answer = re.sub(pattern, replacement, answer, flags=flags)

for pattern in BAD_CITATION_PATTERNS:
find_and_replace(pattern)

return answer, idx

def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer

refs = []
ans = answer.split("</think>")
@@ -350,7 +342,7 @@ def chat(dialog, messages, stream=True, **kwargs):

if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
idx = set([])
if not re.search(r"\[ID:([0-9]+)\]", answer):
if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
answer, idx = retriever.insert_citations(
answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],
@@ -385,13 +377,9 @@ def chat(dialog, messages, stream=True, **kwargs):
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000
bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

tk_num = num_tokens_from_string(think + answer)
@@ -402,12 +390,8 @@ def chat(dialog, messages, stream=True, **kwargs):
f" - Total: {total_time_cost:.1f}ms\n"
f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
f" - Create retriever: {create_retriever_time_cost:.1f}ms\n"
f" - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
f" - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
f" - Multi-turn optimization: {refine_question_time_cost:.1f}ms\n"
f" - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
f" - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
f" - Bind models: {bind_embedding_time_cost:.1f}ms\n"
f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
"## Token usage:\n"

Loading…
Cancel
Save