|
|
|
@@ -127,6 +127,31 @@ def chat_solo(dialog, messages, stream=True):
         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
 
 
+def get_models(dialog):
+    embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
+    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
+    if len(embedding_list) > 1:
+        raise Exception("**ERROR**: Knowledge bases use different embedding models.")
+
+    if embedding_list:
+        embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
+        if not embd_mdl:
+            raise LookupError("Embedding model(%s) not found" % embedding_list[0])
+
+    if llm_id2llm_type(dialog.llm_id) == "image2text":
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
+    else:
+        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+
+    if dialog.rerank_id:
+        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
+
+    if dialog.prompt_config.get("tts"):
+        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
+    return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
+
+
 BAD_CITATION_PATTERNS = [
     re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"),  # (ID: 12)
     re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"),  # [ID: 12]
|
|
|
@@ -134,10 +159,38 @@ BAD_CITATION_PATTERNS = [
     re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE),  # ref12、REF 12
 ]
 
 
+def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
+    max_index = len(kbinfos["chunks"])
+
+    def safe_add(i):
+        if 0 <= i < max_index:
+            idx.add(i)
+            return True
+        return False
+
+    def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
+        nonlocal answer
+
+        def replacement(match):
+            try:
+                i = int(match.group(group_index))
+                if safe_add(i):
+                    return f"[{repl(i)}]"
+            except Exception:
+                pass
+            return match.group(0)
+
+        answer = re.sub(pattern, replacement, answer, flags=flags)
+
+    for pattern in BAD_CITATION_PATTERNS:
+        find_and_replace(pattern)
+
+    return answer, idx
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
-    if not dialog.kb_ids:
+    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
         for ans in chat_solo(dialog, messages, stream):
             yield ans
         return
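`repair_bad_citation_formats` (hoisted here from inside `chat()`; see the matching deletion further down) rewrites the citation spellings models tend to emit, such as `(ID: 12)`, `[ID: 12]` or `ref12`, into the canonical `[ID:n]` form, and records an index only when it points at an actually retrieved chunk. A quick illustration with synthetic inputs:

```python
# Two chunks retrieved, so only indices 0 and 1 are citable.
kbinfos = {"chunks": [{"content_ltks": "alpha"}, {"content_ltks": "beta"}]}
answer, idx = repair_bad_citation_formats("see (ID: 1) and ref99", kbinfos, set())
assert answer == "see [ID:1] and ref99"  # ref99 is out of range, left untouched
assert idx == {1}
```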
|
|
|
@@ -162,45 +215,19 @@ def chat(dialog, messages, stream=True, **kwargs):
         langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}")
 
     check_langfuse_tracer_ts = timer()
 
-    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
-    embedding_list = list(set([kb.embd_id for kb in kbs]))
-    if len(embedding_list) != 1:
-        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
-        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
-
-    embedding_model_name = embedding_list[0]
-
+    kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
+    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
+    if toolcall_session and tools:
+        chat_mdl.bind_tools(toolcall_session, tools)
+    bind_models_ts = timer()
+
     retriever = settings.retrievaler
 
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
     if "doc_ids" in messages[-1]:
         attachments = messages[-1]["doc_ids"]
 
-    create_retriever_ts = timer()
-
-    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
-    if not embd_mdl:
-        raise LookupError("Embedding model(%s) not found" % embedding_model_name)
-
-    bind_embedding_ts = timer()
-
-    if llm_id2llm_type(dialog.llm_id) == "image2text":
-        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
-    else:
-        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
-    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
-    if toolcall_session and tools:
-        chat_mdl.bind_tools(toolcall_session, tools)
-
-    bind_llm_ts = timer()
-
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
-    tts_mdl = None
-    if prompt_config.get("tts"):
-        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
|
|
@@ -225,26 +252,18 @@ def chat(dialog, messages, stream=True, **kwargs):
     if prompt_config.get("cross_languages"):
         questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
 
-    refine_question_ts = timer()
+    if prompt_config.get("keyword", False):
+        questions[-1] += keyword_extraction(chat_mdl, questions[-1])
 
-    rerank_mdl = None
-    if dialog.rerank_id:
-        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
+    refine_question_ts = timer()
 
-    bind_reranker_ts = timer()
-    generate_keyword_ts = bind_reranker_ts
     thought = ""
     kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
 
     if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         knowledges = []
     else:
-        if prompt_config.get("keyword", False):
-            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-            generate_keyword_ts = timer()
-
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
 
         knowledges = []
         if prompt_config.get("reasoning", False):
             reasoner = DeepResearcher(
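Keyword enrichment moves out of the `else` (knowledge-retrieval) branch to run right after the cross-language rewrite, so it now executes even when the prompt has no `knowledge` parameter, and its cost is folded into `refine_question_ts` instead of a separate `generate_keyword_ts`. The step itself just appends extracted keywords to the last user question; mocked here for illustration (real `keyword_extraction` output is model-dependent):

```python
questions = ["how do I rotate my API keys"]
questions[-1] += "API key rotation, credential expiry"  # mocked keyword_extraction()
# Retrieval below now matches on the enriched query text.
```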
|
|
|
@@ -260,21 +279,22 @@ def chat(dialog, messages, stream=True, **kwargs):
                 elif stream:
                     yield think
         else:
-            kbinfos = retriever.retrieval(
-                " ".join(questions),
-                embd_mdl,
-                tenant_ids,
-                dialog.kb_ids,
-                1,
-                dialog.top_n,
-                dialog.similarity_threshold,
-                dialog.vector_similarity_weight,
-                doc_ids=attachments,
-                top=dialog.top_k,
-                aggs=False,
-                rerank_mdl=rerank_mdl,
-                rank_feature=label_question(" ".join(questions), kbs),
-            )
+            if embd_mdl:
+                kbinfos = retriever.retrieval(
+                    " ".join(questions),
+                    embd_mdl,
+                    tenant_ids,
+                    dialog.kb_ids,
+                    1,
+                    dialog.top_n,
+                    dialog.similarity_threshold,
+                    dialog.vector_similarity_weight,
+                    doc_ids=attachments,
+                    top=dialog.top_k,
+                    aggs=False,
+                    rerank_mdl=rerank_mdl,
+                    rank_feature=label_question(" ".join(questions), kbs),
+                )
             if prompt_config.get("tavily_api_key"):
                 tav = Tavily(prompt_config["tavily_api_key"])
                 tav_res = tav.retrieve_chunks(" ".join(questions))
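Since dialogs without knowledge bases can now reach this point (see the relaxed `kb_ids` guard in `chat()` above), `embd_mdl` may be `None`; wrapping vector retrieval in `if embd_mdl:` leaves `kbinfos` at its empty default so that Tavily web results alone can feed the answer. Condensed control flow:

```python
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}  # default when retrieval is skipped
if embd_mdl:  # None when the dialog has no knowledge bases bound
    kbinfos = retriever.retrieval(
        " ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
        dialog.similarity_threshold, dialog.vector_similarity_weight,
        doc_ids=attachments, top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
        rank_feature=label_question(" ".join(questions), kbs),
    )
if prompt_config.get("tavily_api_key"):
    tav_res = Tavily(prompt_config["tavily_api_key"]).retrieve_chunks(" ".join(questions))
    # tav_res is merged into kbinfos downstream (outside this hunk)
```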
|
|
|
@@ -310,36 +330,8 @@ def chat(dialog, messages, stream=True, **kwargs):
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
 
-    def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
-        max_index = len(kbinfos["chunks"])
-
-        def safe_add(i):
-            if 0 <= i < max_index:
-                idx.add(i)
-                return True
-            return False
-
-        def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
-            nonlocal answer
-
-            def replacement(match):
-                try:
-                    i = int(match.group(group_index))
-                    if safe_add(i):
-                        return f"[{repl(i)}]"
-                except Exception:
-                    pass
-                return match.group(0)
-
-            answer = re.sub(pattern, replacement, answer, flags=flags)
-
-        for pattern in BAD_CITATION_PATTERNS:
-            find_and_replace(pattern)
-
-        return answer, idx
-
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
+        nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
 
         refs = []
         ans = answer.split("</think>")
|
|
|
@@ -350,7 +342,7 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             idx = set([])
-            if not re.search(r"\[ID:([0-9]+)\]", answer):
+            if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
                 answer, idx = retriever.insert_citations(
                     answer,
                     [ck["content_ltks"] for ck in kbinfos["chunks"]],
|
|
|
@@ -385,13 +377,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
         check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
         check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
-        create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000
-        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
-        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
-        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
-        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
-        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
-        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
+        bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
+        refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
+        retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
         generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
 
         tk_num = num_tokens_from_string(think + answer)
|
|
|
@@ -402,12 +390,8 @@ def chat(dialog, messages, stream=True, **kwargs):
             f" - Total: {total_time_cost:.1f}ms\n"
             f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
             f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
-            f" - Create retriever: {create_retriever_time_cost:.1f}ms\n"
-            f" - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
-            f" - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
-            f" - Multi-turn optimization: {refine_question_time_cost:.1f}ms\n"
-            f" - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
-            f" - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
+            f" - Bind models: {bind_embedding_time_cost:.1f}ms\n"
+            f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n"
             f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
             f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
             "## Token usage:\n"