dialog_service.py

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import binascii
import logging
import re
import time
from copy import deepcopy
from functools import partial
from timeit import default_timer as timer

from agentic_reasoning import DeepResearcher
from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())

        chats = chats.paginate(page_number, items_per_page)

        return list(chats.dicts())
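

# Illustrative sketch (not part of the original module): how a caller might
# page through one tenant's dialogs with DialogService.get_list. The tenant id
# and filter values below are hypothetical.
#
#     chats = DialogService.get_list(
#         "tenant-0001",                 # tenant_id (hypothetical)
#         page_number=1, items_per_page=20,
#         orderby="create_time", desc=True,
#         id=None, name=None,            # no id/name filter
#     )
#     for c in chats:
#         print(c["id"], c["name"])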


def chat_solo(dialog, messages, stream=True):
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
           for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # Flush any tail that stayed under the 16-token threshold, mirroring
        # the pattern used in chat() below; otherwise the end of the answer
        # would never reach the client.
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
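

# Illustrative sketch (not part of the original module): chat_solo is a
# generator, so callers drain it even when stream=False. `dialog` is assumed
# to be a Dialog row whose kb_ids is empty; render/play are hypothetical.
#
#     history = [{"role": "user", "content": "Hello!"}]
#     for delta in chat_solo(dialog, history, stream=True):
#         render(delta["answer"])          # cumulative answer so far
#         play(delta["audio_binary"])      # hex-encoded TTS audio, or None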


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()

    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))

        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(chat_mdl,
                                      prompt_config,
                                      partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                else:
                    yield think
        else:
            kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                          dialog.similarity_threshold,
                                          dialog.vector_similarity_weight,
                                          doc_ids=attachments,
                                          top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
                                          rank_feature=label_question(" ".join(questions), kbs)
                                          )
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions),
                                                       tenant_ids,
                                                       dialog.kb_ids,
                                                       embd_mdl,
                                                       LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

            knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()

    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]
    prompt += "\n\n### Query:\n%s" % " ".join(questions)

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)
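
    # The clamp above keeps prompt plus completion inside the model window:
    # e.g., with max_tokens=8192 and used_token_count=5000, a configured
    # "max_tokens" of 4096 becomes min(4096, 8192 - 5000) = 3192 completion
    # tokens. (Worked numbers are illustrative, not from the source.)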

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts

        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retriever.insert_citations(answer,
                                                     [ck["content_ltks"]
                                                      for ck in kbinfos["chunks"]],
                                                     [ck["vector"]
                                                      for ck in kbinfos["chunks"]],
                                                     embd_mdl,
                                                     tkweight=1 - dialog.vector_similarity_weight,
                                                     vtweight=dialog.vector_similarity_weight)
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}

    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res
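

# Illustrative sketch (not part of the original module): each event from
# `chat` carries the cumulative answer; the final event adds `reference` and
# the timing-annotated `prompt`. `dialog` is assumed to have kb_ids set.
#
#     history = [{"role": "user", "content": "Summarize the Q3 report."}]
#     final = None
#     for delta in chat(dialog, history, stream=True):
#         final = delta
#     print(final["answer"])
#     print(final["reference"]["doc_aggs"])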


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        # Strip everything after a (half- or full-width) semicolon or a code fence.
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[:len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[:len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(
            index_name(tenant_id),
            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
            question, sql, tbl["error"]
        )
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table; strip "/..." suffixes and full-width
    # parenthesized remarks from the field labels
    columns = "|" + "|".join(
        [re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"]))
         for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        # `quota` mirrors prompt_config["quote"]: append " ##N$$" citation markers per row
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([columns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }
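

# Note on the table emitted above (grounded in the code): when `quota` is
# truthy, each Markdown row is suffixed with a citation marker of the form
# " ##<row-index>$$ |", which the chat layer strips again via its
# re.sub(r"##\d+\$\$", ...) calls. A two-row result would look like:
#
#     |Name|Age|Source|
#     |------|------|------|
#     |Alice|30| ##0$$ |
#     |Bob|28| ##1$$ |
#
# (column names and values here are hypothetical).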


def tts(tts_mdl, text):
    if not tts_mdl or not text:
        return
    bin_data = b""  # renamed from `bin` to avoid shadowing the built-in
    for chunk in tts_mdl.tts(text):
        bin_data += chunk
    return binascii.hexlify(bin_data).decode("utf-8")
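

# Illustrative sketch (not part of the original module): the hex string
# returned by `tts` can be turned back into raw audio bytes with the
# standard library; the file name is hypothetical and the container format
# depends on the TTS model.
#
#     hex_audio = tts(tts_mdl, "Hello")
#     if hex_audio:
#         audio_bytes = binascii.unhexlify(hex_audio)
#         with open("answer_audio.bin", "wb") as f:
#             f.write(audio_bytes)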


def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
                                  1, 12, 0.1, 0.3, aggs=False,
                                  rank_feature=label_question(question, kbs)
                                  )
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restriction:
  - DO NOT make things up, especially for numbers.
  - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
  - Answer with markdown format text.
  - Answer in language of user's question.
  - DO NOT make things up, especially for numbers.

### Information from knowledge bases
%s

The above is information from knowledge bases.
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer,
                                                 [ck["content_ltks"]
                                                  for ck in kbinfos["chunks"]],
                                                 [ck["vector"]
                                                  for ck in kbinfos["chunks"]],
                                                 embd_mdl,
                                                 tkweight=0.7,
                                                 vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
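

# Illustrative sketch (not part of the original module): `ask` streams
# cumulative answers and finishes with a cited, reference-bearing payload.
# The kb_ids and tenant_id values below are hypothetical.
#
#     last = None
#     for delta in ask("Who wrote the design doc?", ["kb-123"], "tenant-0001"):
#         last = delta
#     print(last["answer"])                    # includes ##N$$ citation markers
#     print(len(last["reference"]["chunks"]))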