dialog_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
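"""Dialog service: chat orchestration over knowledge bases.

Contains the DialogService model accessor, the generator-based `chat`
pipeline (retrieval, optional deep reasoning, Tavily web search,
knowledge-graph retrieval, citation insertion, TTS), the SQL-based
retrieval path `use_sql`, and the one-shot `ask` helper.
"""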
import logging
import binascii
import re
import time
from copy import deepcopy
from functools import partial
from timeit import default_timer as timer

from agentic_reasoning import DeepResearcher
from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        # Only valid (non-deleted) dialogs that belong to this tenant.
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())
        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())

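
# chat_solo handles dialogs with no knowledge base attached: it talks to the
# chat (or image2text) model directly, with optional TTS, and batches streamed
# deltas to at least ~16 tokens so audio chunks are not too fine-grained.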
def chat_solo(dialog, messages, stream=True):
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # Drop system messages and strip citation markers (##N$$) from history.
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
           for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # Flush whatever remained below the 16-token threshold.
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}

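
# chat is the main RAG pipeline and, like chat_solo, a generator: it yields
# partial answers while streaming and one final decorated answer (references,
# prompt, per-stage timing) at the end. The *_ts timestamps taken along the
# way feed the timing report assembled in decorate_answer below.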
def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    # Only the last three user turns are considered for retrieval.
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)

    # Try to use SQL if the field mapping is good to go.
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return
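
    # Substitute optional prompt parameters into the system template;
    # a missing required parameter is a hard error.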
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        # Rewrite the multi-turn history into one standalone question.
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()
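
    # Retrieval stage: either a DeepResearcher reasoning loop or a single
    # vector/keyword retrieval pass, optionally augmented with Tavily web
    # search and knowledge-graph retrieval, then packed into the token budget.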
    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))

        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(chat_mdl,
                                      prompt_config,
                                      partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                          dialog.similarity_threshold,
                                          dialog.vector_similarity_weight,
                                          doc_ids=attachments,
                                          top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
                                          rank_feature=label_question(" ".join(questions), kbs)
                                          )
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions),
                                                       tenant_ids,
                                                       dialog.kb_ids,
                                                       embd_mdl,
                                                       LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

            knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()

    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        # Leave room for the prompt inside the model's context window.
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)
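
    # decorate_answer post-processes the final LLM output: split off any
    # <think> block, insert ##N$$ citations against the retrieved chunks,
    # strip vectors from the returned references, and append the timing report.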
    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retriever.insert_citations(answer,
                                                     [ck["content_ltks"]
                                                      for ck in kbinfos["chunks"]],
                                                     [ck["vector"]
                                                      for ck in kbinfos["chunks"]],
                                                     embd_mdl,
                                                     tkweight=1 - dialog.vector_similarity_weight,
                                                     vtweight=dialog.vector_similarity_weight)
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
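
    # Stream deltas batched to >=16 tokens, flush the remainder, then emit
    # the final decorated answer.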
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res

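
# use_sql has the LLM translate the user's question into SQL over the
# field-mapped table, executes it via settings.retrievaler.sql_retrieval,
# retries once with the database error fed back, and renders the result
# as a Markdown table with optional ##N$$ citation markers.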
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        # Normalize the generated SQL: one line, lower-cased, stripped of
        # leading chatter, duplicated spaces, and trailing fences/semicolons.
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[:len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            # Plain selects must also return doc_id/docnm_kwd for citations.
            if sql[:len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        # Feed the failing SQL and the database error back for one retry.
        user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(
            index_name(tenant_id),
            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
            question, sql, tbl["error"]
        )
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl is None or tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # Compose a Markdown table; map raw column names through field_map and
    # strip trailing "/..." or full-width parenthesized notes from labels.
    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                        tbl["columns"][i]["name"])) for i in
                              column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")
    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and doc_name_idx else "")
    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        # Append citation markers (##N$$) so the UI can link rows to sources.
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    # Drop time-of-day suffixes from ISO timestamps.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([columns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }

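
# tts synthesizes speech for the given text with the bound TTS model and
# returns it as a hex-encoded string (None when no model or text is given).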
def tts(tts_mdl, text):
    if not tts_mdl or not text:
        return
    bin_data = b""
    for chunk in tts_mdl.tts(text):
        bin_data += chunk
    # Hex-encode the audio bytes for JSON-safe transport.
    return binascii.hexlify(bin_data).decode("utf-8")

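
# ask is a lightweight one-shot QA helper: retrieve once, build a fixed
# summarization prompt, stream the raw answer, then yield it again decorated
# with citations.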
def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))

    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
                                  1, 12, 0.1, 0.3, aggs=False,
                                  rank_feature=label_question(question, kbs)
                                  )
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restrictions:
  - DO NOT make things up, especially for numbers.
  - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
  - Answer with markdown format text.
  - Answer in language of user's question.
  - DO NOT make things up, especially for numbers.

### Information from knowledge bases
%s

The above is information from knowledge bases.
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer,
                                                 [ck["content_ltks"]
                                                  for ck in kbinfos["chunks"]],
                                                 [ck["vector"]
                                                  for ck in kbinfos["chunks"]],
                                                 embd_mdl,
                                                 tkweight=0.7,
                                                 vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
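

# Minimal usage sketch (illustrative only; assumes `dialog` is a Dialog ORM
# object already configured with kb_ids, llm_id, llm_setting and prompt_config):
#
#     history = [{"role": "user", "content": "What is covered in the handbook?"}]
#     for delta in chat(dialog, history, stream=True):
#         print(delta["answer"])  # growing partial answer
#     # The final yielded dict also carries "reference", "prompt" and
#     # "created_at" fields produced by decorate_answer.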