
dialog_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
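"""Service layer for dialogs: dialog list/CRUD helpers, retrieval-augmented
chat (with optional keyword extraction, deep-research reasoning, Tavily web
search, and knowledge-graph retrieval), SQL-based retrieval over mapped
fields, and a text-to-speech helper."""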
import logging
import binascii
import time
from functools import partial
import re
from copy import deepcopy
from timeit import default_timer as timer

from agentic_reasoning import DeepResearcher
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \
    citation_prompt
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
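    """CRUD service bound to the Dialog table via CommonService."""
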
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
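        """Return one page of the tenant's valid dialogs as dicts, optionally
        filtered by exact id/name and ordered by `orderby` asc/desc."""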
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())
        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())


def chat_solo(dialog, messages, stream=True):
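    """Chat without any knowledge base: send the history straight to the LLM
    and yield answers, streamed in roughly 16-token deltas when `stream`."""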
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
           for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # re-compute the tail skipped by the 16-token threshold so it is
        # emitted exactly once (mirrors the streaming loop in chat())
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}


def chat(dialog, messages, stream=True, **kwargs):
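    """Retrieval-augmented chat over the dialog's knowledge bases.

    Falls back to chat_solo() when no KB is attached, and to SQL-based
    retrieval when the KBs expose a field map. Otherwise retrieves chunks
    (optionally with keyword extraction, deep-research reasoning, Tavily web
    search, and knowledge-graph lookup), builds the prompt, and yields
    streamed deltas and a final answer with citations and timing stats.
    """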
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL for retrieval: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Missing parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()

    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
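
    # Gather knowledge: skip entirely if the prompt has no {knowledge} slot;
    # otherwise run deep-research reasoning or plain vector retrieval, then
    # optionally merge Tavily web results and knowledge-graph chunks.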
    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))

        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(chat_mdl,
                                      prompt_config,
                                      partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                          dialog.similarity_threshold,
                                          dialog.vector_similarity_weight,
                                          doc_ids=attachments,
                                          top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
                                          rank_feature=label_question(" ".join(questions), kbs)
                                          )
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions),
                                                       tenant_ids,
                                                       dialog.kb_ids,
                                                       embd_mdl,
                                                       LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

            knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)

    def decorate_answer(answer):
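        """Split off any <think> prefix, insert or validate citations, keep
        only referenced documents, and append timing diagnostics to the
        prompt before returning the final answer payload."""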
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions

        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
            if not re.search(r"##[0-9]+\$\$", answer):
                answer, idx = retriever.insert_citations(answer,
                                                         [ck["content_ltks"]
                                                          for ck in kbinfos["chunks"]],
                                                         [ck["vector"]
                                                          for ck in kbinfos["chunks"]],
                                                         embd_mdl,
                                                         tkweight=1 - dialog.vector_similarity_weight,
                                                         vtweight=dialog.vector_similarity_weight)
            else:
                idx = set([])
                for r in re.finditer(r"##([0-9]+)\$\$", answer):
                    i = int(r.group(1))
                    if i < len(kbinfos["chunks"]):
                        idx.add(i)

            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
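
    # Stream in ~16-token deltas (stripping <think> blocks when reasoning
    # ran), flush any remaining tail, then emit the decorated final answer.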
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
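    """Answer a question by letting the LLM generate SQL over the mapped
    fields and rendering the result as a Markdown table; `quota` appends
    ##i$$ citation markers to each row. Returns None when no usable SQL or
    result set is produced."""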
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
The table's database fields are as follows:
{}
The questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0

    def get_table():
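        """Ask the LLM for SQL, normalize it (strip <think> blocks, force a
        SELECT that includes doc_id/docnm_kwd), execute it, and return the
        (result, sql) pair."""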
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[:len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[:len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql
    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
The table's database fields are as follows:
{}
The questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The erroneous SQL you provided last time is as follows:
{}
The error reported by the database is as follows:
{}
Please correct the error and write the SQL again, only SQL, without any other explanations or text.
""".format(
            index_name(tenant_id),
            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
            question, sql, tbl["error"]
        )
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    # the retry may also fail to produce SQL, so guard against tbl is None
    if tbl is None or tbl.get("error") or len(tbl["rows"]) == 0:
        return None
    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    columns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                          tbl["columns"][i]["name"])) for i in
                              column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([columns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }


def tts(tts_mdl, text):
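    """Synthesize `text` with the TTS model and return the audio as a hex
    string, or None when no model or text is available."""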
    if not tts_mdl or not text:
        return
    binary = b""
    for chunk in tts_mdl.tts(text):
        binary += chunk
    return binascii.hexlify(binary).decode("utf-8")


def ask(question, kb_ids, tenant_id):
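    """One-shot Q&A over the given knowledge bases: retrieve chunks, prompt
    the tenant's default chat model, stream the raw answer, then yield the
    final answer decorated with citations and references."""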
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
                                  1, 12, 0.1, 0.3, aggs=False,
                                  rank_feature=label_question(question, kbs)
                                  )
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from the knowledge bases and answer the user's question.
Requirements and restrictions:
- DO NOT make things up, especially for numbers.
- If the information from the knowledge bases is irrelevant to the user's question, JUST SAY: Sorry, no relevant information provided.
- Answer in Markdown format.
- Answer in the language of the user's question.
- DO NOT make things up, especially for numbers.
### Information from knowledge bases
%s
The above is information from knowledge bases.
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
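        """Insert citations, keep only the referenced documents, strip raw
        vectors, and format chunks for the final response payload."""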
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer,
                                                 [ck["content_ltks"]
                                                  for ck in kbinfos["chunks"]],
                                                 [ck["vector"]
                                                  for ck in kbinfos["chunks"]],
                                                 embd_mdl,
                                                 tkweight=0.7,
                                                 vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)