
dialog_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import binascii
import logging
import re
import time
from copy import deepcopy
from functools import partial
from timeit import default_timer as timer

from langfuse import Langfuse

from agentic_reasoning import DeepResearcher
from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
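        """Return the tenant's valid dialogs as dicts, optionally filtered by id/name, sorted by `orderby`, and paginated."""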
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())

        chats = chats.paginate(page_number, items_per_page)

        return list(chats.dicts())


def chat_solo(dialog, messages, stream=True):
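    """Chat without knowledge-base retrieval: forward the conversation straight to the LLM.

    A generator. When streaming, deltas are buffered until they reach roughly
    16 tokens so TTS audio chunks are not too fine-grained.
    """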
    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
        # Flush whatever is still buffered once the stream ends.
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}


def chat(dialog, messages, stream=True, **kwargs):
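    """RAG chat over the dialog's knowledge bases.

    A generator that yields partial answers (when streaming) and finally a
    decorated answer carrying references, the effective prompt, and per-stage
    timing. Falls back to chat_solo() when no knowledge bases are attached.
    """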
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids:
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    langfuse_tracer = None
    langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
    if langfuse_keys:
        langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
        if langfuse.auth_check():
            langfuse_tracer = langfuse
            langfuse.trace = langfuse_tracer.trace(name=f"{dialog.name}-{llm_model_config['llm_name']}")

    check_langfuse_tracer_ts = timer()

    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    embedding_model_name = embedding_list[0]

    retriever = settings.retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    create_retriever_ts = timer()

    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embedding_model_name)

    bind_embedding_ts = timer()

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
    if toolcall_session and tools:
        chat_mdl.bind_tools(toolcall_session, tools)

    bind_llm_ts = timer()

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    # Try SQL retrieval first when the knowledge bases expose a field mapping.
    if field_map:
        logging.debug("Use SQL retrieval: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Missing parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
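
    # Multi-turn refinement: fold the recent user turns into one standalone question
    # (via full_question) so retrieval is not confused by pronouns and follow-ups.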
    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    refine_question_ts = timer()

    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    bind_reranker_ts = timer()
    generate_keyword_ts = bind_reranker_ts
    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            generate_keyword_ts = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))

        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(
                chat_mdl,
                prompt_config,
                partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3),
            )

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            kbinfos = retriever.retrieval(
                " ".join(questions),
                embd_mdl,
                tenant_ids,
                dialog.kb_ids,
                1,
                dialog.top_n,
                dialog.similarity_threshold,
                dialog.vector_similarity_weight,
                doc_ids=attachments,
                top=dialog.top_k,
                aggs=False,
                rerank_mdl=rerank_mdl,
                rank_feature=label_question(" ".join(questions), kbs),
            )
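            # Optionally augment KB hits with Tavily web-search results and
            # knowledge-graph chunks before composing the prompt context.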
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

            knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]
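
    # message_fit_in already trimmed history to 95% of the context window; also cap
    # the generation budget so prompt + completion fit within the model's max_tokens.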
    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

    def decorate_answer(answer):
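        """Post-process the model output: split off any <think> prefix, verify or
        insert ##N$$ citation markers, strip vectors from the references, and
        append per-stage timing and token-usage stats to the stored prompt."""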
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer

        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
            if not re.search(r"##[0-9]+\$\$", answer):
                answer, idx = retriever.insert_citations(
                    answer,
                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
                    [ck["vector"] for ck in kbinfos["chunks"]],
                    embd_mdl,
                    tkweight=1 - dialog.vector_similarity_weight,
                    vtweight=dialog.vector_similarity_weight,
                )
            else:
                idx = set([])
                for r in re.finditer(r"##([0-9]+)\$\$", answer):
                    i = int(r.group(1))
                    if i < len(kbinfos["chunks"]):
                        idx.add(i)

            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
        create_retriever_time_cost = (create_retriever_ts - check_langfuse_tracer_ts) * 1000
        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        tk_num = num_tokens_from_string(think + answer)
        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = (
            f"{prompt}\n\n"
            "## Time elapsed:\n"
            f"  - Total: {total_time_cost:.1f}ms\n"
            f"  - Check LLM: {check_llm_time_cost:.1f}ms\n"
            f"  - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
            f"  - Create retriever: {create_retriever_time_cost:.1f}ms\n"
            f"  - Bind embedding: {bind_embedding_time_cost:.1f}ms\n"
            f"  - Bind LLM: {bind_llm_time_cost:.1f}ms\n"
            f"  - Multi-turn optimization: {refine_question_time_cost:.1f}ms\n"
            f"  - Bind reranker: {bind_reranker_time_cost:.1f}ms\n"
            f"  - Generate keyword: {generate_keyword_time_cost:.1f}ms\n"
            f"  - Retrieval: {retrieval_time_cost:.1f}ms\n"
            f"  - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
            "## Token usage:\n"
            f"  - Generated tokens(approximately): {tk_num}\n"
            f"  - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
        )

        langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
        langfuse_output = {"time_elapsed:": re.sub(r"\n", "  \n", langfuse_output), "created_at": time.time()}

        # End the Langfuse generation span only if one was actually opened.
        if langfuse_tracer and "langfuse_generation" in locals():
            langfuse_generation.end(output=langfuse_output)

        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}
    if langfuse_tracer:
        langfuse_generation = langfuse_tracer.trace.generation(name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg})

    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            if thought:
                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
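    """Answer a question by letting the LLM write SQL over the tenant's indexed fields.

    The generated SQL is sanitized, executed via settings.retrievaler.sql_retrieval,
    and the result set is rendered as a Markdown table. When `quota` is truthy, each
    row gets a ##N$$ citation marker, e.g. (illustrative):

        |Name|Age|Source|
        |------|------|------|
        |Alice|30| ##0$$ |

    Returns None if no usable SQL or no rows could be produced.
    """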
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
    tried_times = 0

    def get_table():
  331. def get_table():
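        """Ask the LLM for SQL, sanitize it, ensure doc_id/docnm_kwd are in the
        select list, then execute it; returns (result_table, sql), or (None, None)
        if the model did not produce a select statement."""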
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[: len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[: len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")

        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.

The SQL error you provided last time is as follows:
{}

Error issued by database as follows:
{}

Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose Markdown table
    columns = (
        "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")
    )

    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and doc_name_idx else "")

    rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }


def tts(tts_mdl, text):
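    """Synthesize `text` with the TTS model and return the audio hex-encoded, or None if unavailable."""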
    if not tts_mdl or not text:
        return
    bin_data = b""
    for chunk in tts_mdl.tts(text):
        bin_data += chunk
    return binascii.hexlify(bin_data).decode("utf-8")


def ask(question, kb_ids, tenant_id):
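    """One-shot RAG Q&A over `kb_ids`: retrieve context, stream the LLM answer,
    then yield a final payload decorated with citations and references.
    """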
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))
    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
    knowledges = kb_prompt(kbinfos, max_tokens)
    prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restrictions:
  - DO NOT make things up, especially for numbers.
  - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
  - Answer with markdown format text.
  - Answer in language of user's question.
  - DO NOT make things up, especially for numbers.

### Information from knowledge bases
%s

The above is information from knowledge bases.
""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
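        """Insert citations into the final answer and build the de-vectorized reference payload."""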
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)