
dialog_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import binascii
import logging
import re
import time
from copy import deepcopy
from functools import partial
from timeit import default_timer as timer

from langfuse import Langfuse

from agentic_reasoning import DeepResearcher
from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())

        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())
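
# A hypothetical usage sketch (parameter names as in get_list above; the
# "create_time" orderby value is an assumption, not taken from this file):
#
#   DialogService.get_list("tenant_0", page_number=1, items_per_page=20,
#                          orderby="create_time", desc=True, id=None, name=None)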


def chat_solo(dialog, messages, stream=True):
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans) :]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # flush only the portion that was never yielded (mirrors the post-loop
        # recomputation in chat() below)
        delta_ans = answer[len(last_ans) :]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}
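
# Note on the streaming loop above: deltas of fewer than 16 tokens are buffered
# (the `continue`) so audio/network chunks stay reasonably sized, and the
# post-loop recomputation of `delta_ans` flushes only whatever remainder was
# never yielded before the model finished.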


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    langfuse_tracer = None
    langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
    if langfuse_keys:
        langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
        if langfuse.auth_check():
            langfuse_tracer = langfuse
            langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}")

    check_langfuse_tracer_ts = timer()
    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # try SQL retrieval first if a field mapping is available
    if field_map:
        logging.debug("Use SQL to retrieve: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Missing parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()
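
    # Illustrative (hypothetical) condensation for the refine_multiturn branch
    # above: given the turns ["Who wrote Dune?", "When was it published?"],
    # full_question is expected to return a standalone query along the lines of
    # "When was Dune by Frank Herbert published?".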
    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))

        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(
                chat_mdl,
                prompt_config,
                partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3),
            )

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            kbinfos = retriever.retrieval(
                " ".join(questions),
                embd_mdl,
                tenant_ids,
                dialog.kb_ids,
                1,
                dialog.top_n,
                dialog.similarity_threshold,
                dialog.vector_similarity_weight,
                doc_ids=attachments,
                top=dialog.top_k,
                aggs=False,
                rerank_mdl=rerank_mdl,
                rank_feature=label_question(" ".join(questions), kbs),
            )
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

        knowledges = kb_prompt(kbinfos, max_tokens)
        logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
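
    # Worked example of the clamp above: with max_tokens = 8192 and
    # used_token_count = 3000, a requested gen_conf["max_tokens"] of 8000 is
    # reduced to min(8000, 8192 - 3000) = 5192, so that prompt plus completion
    # stay within the model window.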

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions

        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
            if not re.search(r"##[0-9]+\$\$", answer):
                answer, idx = retriever.insert_citations(
                    answer,
                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
                    [ck["vector"] for ck in kbinfos["chunks"]],
                    embd_mdl,
                    tkweight=1 - dialog.vector_similarity_weight,
                    vtweight=dialog.vector_similarity_weight,
                )
            else:
                idx = set([])
                for r in re.finditer(r"##([0-9]+)\$\$", answer):
                    i = int(r.group(1))
                    if i < len(kbinfos["chunks"]):
                        idx.add(i)

            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        tk_num = num_tokens_from_string(think + answer)
        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = (
            f"{prompt}\n\n"
            "## Time elapsed:\n"
            f" - Total: {total_time_cost:.1f}ms\n"
            f" - Check LLM: {check_llm_time_cost:.1f}ms\n"
            f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
            f" - Create retriever: {create_retriever_time_cost:.1f}ms\n"
            f" - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
            f" - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
            f" - Tune question: {refine_question_time_cost:.1f}ms\n"
            f" - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
            f" - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
            f" - Retrieval: {retrieval_time_cost:.1f}ms\n"
            f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
            "## Token usage:\n"
            f" - Generated tokens(approximately): {tk_num}\n"
            f" - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
        )

        langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
        langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()}

        # `langfuse_generation` only exists when a tracer was configured below;
        # guard the call so chats without Langfuse credentials do not raise a
        # NameError.
        if langfuse_tracer:
            langfuse_generation.end(output=langfuse_output)

        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}

    if langfuse_tracer:
        langfuse_generation = langfuse_tracer.trace.generation(name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg})

    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans) :]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans) :]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res
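
# Minimal consumption sketch for chat() (the dialog object and its fields are
# assumed to be configured elsewhere):
#
#   for chunk in chat(dialog, [{"role": "user", "content": "What is RAGFlow?"}]):
#       print(chunk["answer"])
#
# Intermediate chunks carry the growing partial answer; the final chunk, built
# by decorate_answer, adds the resolved "reference" payload and timing report.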


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        # the character class includes the full-width semicolon "；", which the
        # model may emit alongside code fences
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[: len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[: len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")

        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql
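
    # Example of the normalization inside get_table (hypothetical model output):
    #   "```sql\nSELECT Name, Age FROM t;\n```"  ->  "select name, age from t"
    # after which doc_id/docnm_kwd are prepended unless the query aggregates.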
    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
Table of database fields are as follows:
{}
Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
The SQL error you provided last time is as follows:
{}
Error issued by database as follows:
{}
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None
    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose a Markdown table
    columns = (
        "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")
    )

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        # without quoting, emit plain rows with no ##i$$ citation markers
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}
    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
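
# The "answer" returned above is a Markdown table; with quota=True it looks
# roughly like (illustrative values):
#
#   |Name|Age|Source|
#   |------|------|------|
#   |Alice|30| ##0$$ |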


def tts(tts_mdl, text):
    if not tts_mdl or not text:
        return
    # accumulate the audio stream (renamed from `bin` to avoid shadowing the
    # built-in)
    binary = b""
    for chunk in tts_mdl.tts(text):
        binary += chunk
    return binascii.hexlify(binary).decode("utf-8")
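
# tts() returns the audio as a hex string, e.g. binascii.hexlify(b"\x00\x7f")
# yields b"007f", which .decode("utf-8") turns into the string "007f".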


def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restrictions:
  - DO NOT make things up, especially for numbers.
  - If the information from the knowledge bases is irrelevant to the user's question, JUST SAY: Sorry, no relevant information provided.
  - Answer with markdown format text.
  - Answer in the language of the user's question.
### Information from knowledge bases
%s
The above is information from knowledge bases.
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
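
# Minimal consumption sketch for ask() (the kb and tenant IDs are placeholders):
#
#   for chunk in ask("What is RAGFlow?", ["kb_id_0"], "tenant_0"):
#       print(chunk["answer"])
#
# The final yielded chunk carries the citation-inserted answer and a formatted
# "reference" structure produced by decorate_answer.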