
dialog_service.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import binascii
import logging
import re
import time
from copy import deepcopy
from datetime import datetime
from functools import partial
from timeit import default_timer as timer

import trio
from langfuse import Langfuse
from peewee import fn

from agentic_reasoning import DeepResearcher
from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
from rag.prompts.prompts import ASK_SUMMARY, PROMPT_JINJA_ENV, gen_meta_filter
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily


class DialogService(CommonService):
    model = Dialog

    @classmethod
    def save(cls, **kwargs):
        """Save a new record to database.

        This method creates a new record in the database with the provided field values,
        forcing an insert operation rather than an update.

        Args:
            **kwargs: Record field values as keyword arguments.

        Returns:
            Model instance: The created record object.
        """
        sample_obj = cls.model(**kwargs).save(force_insert=True)
        return sample_obj

    @classmethod
    def update_many_by_id(cls, data_list):
        """Update multiple records by their IDs.

        This method updates multiple records in the database, identified by their IDs.
        It automatically updates the update_time and update_date fields for each record.

        Args:
            data_list (list): List of dictionaries containing record data to update.
                Each dictionary must include an 'id' field.
        """
        with DB.atomic():
            for data in data_list:
                data["update_time"] = current_timestamp()
                data["update_date"] = datetime_format(datetime.now())
                cls.model.update(data).where(cls.model.id == data["id"]).execute()

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value))
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())
        chats = chats.paginate(page_number, items_per_page)
        return list(chats.dicts())

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None):
        from api.db.db_models import User

        fields = [
            cls.model.id,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.language,
            cls.model.llm_id,
            cls.model.llm_setting,
            cls.model.prompt_type,
            cls.model.prompt_config,
            cls.model.similarity_threshold,
            cls.model.vector_similarity_weight,
            cls.model.top_n,
            cls.model.top_k,
            cls.model.do_refer,
            cls.model.rerank_id,
            cls.model.kb_ids,
            cls.model.icon,
            cls.model.status,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
            cls.model.update_time,
            cls.model.create_time,
        ]
        if keywords:
            dialogs = (
                cls.model.select(*fields)
                .join(User, on=(cls.model.tenant_id == User.id))
                .where(
                    (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
                    (fn.LOWER(cls.model.name).contains(keywords.lower())),
                )
            )
        else:
            dialogs = (
                cls.model.select(*fields)
                .join(User, on=(cls.model.tenant_id == User.id))
                .where(
                    (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value),
                )
            )
        if parser_id:
            dialogs = dialogs.where(cls.model.parser_id == parser_id)
        if desc:
            dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc())
        else:
            dialogs = dialogs.order_by(cls.model.getter_by(orderby).asc())

        count = dialogs.count()
        if page_number and items_per_page:
            dialogs = dialogs.paginate(page_number, items_per_page)

        return list(dialogs.dicts()), count
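

# chat_solo handles dialogs that have neither a knowledge base nor a Tavily web
# search key: the conversation goes straight to the LLM. When streaming, deltas
# are batched until they reach roughly 16 tokens to limit yield churn.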
def chat_solo(dialog, messages, stream=True):
    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    tts_mdl = None
    if prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]
    if stream:
        last_ans = ""
        delta_ans = ""
        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
            if ans is None:
                continue  # SSDESK: work around RKLLM responses that arrive without a terminator
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
            delta_ans = ""
        if delta_ans:
            yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
    else:
        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}


def get_models(dialog):
    embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None
    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    if len(embedding_list) > 1:
        raise Exception("**ERROR**: Knowledge bases use different embedding models.")

    if embedding_list:
        embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0])
        if not embd_mdl:
            raise LookupError("Embedding model(%s) not found" % embedding_list[0])

    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
    if dialog.prompt_config.get("tts"):
        tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
    return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl
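

# LLMs sometimes cite chunks in formats other than the canonical [ID:i]; these
# patterns catch the common variants so they can be normalized afterwards.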
BAD_CITATION_PATTERNS = [
    re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"),  # (ID: 12)
    re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"),  # [ID: 12]
    re.compile(r"【\s*ID\s*[: ]*\s*(\d+)\s*】"),  # 【ID: 12】
    re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE),  # ref12, REF 12
]


def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    max_index = len(kbinfos["chunks"])

    def safe_add(i):
        if 0 <= i < max_index:
            idx.add(i)
            return True
        return False

    def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0):
        nonlocal answer

        def replacement(match):
            try:
                i = int(match.group(group_index))
                if safe_add(i):
                    return f"[{repl(i)}]"
            except Exception:
                pass
            return match.group(0)

        answer = re.sub(pattern, replacement, answer, flags=flags)

    for pattern in BAD_CITATION_PATTERNS:
        find_and_replace(pattern)

    return answer, idx
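
# Illustrative example for repair_bad_citation_formats (hypothetical data):
#   answer, idx = repair_bad_citation_formats("see (ID: 2)", {"chunks": [c0, c1, c2]}, set())
# rewrites the answer to "see [ID:2]" and records index 2 in idx; out-of-range
# indices are left as-is.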


def meta_filter(metas: dict, filters: list[dict]):
    doc_ids = set([])

    def filter_out(v2docs, operator, value):
        ids = []
        for input, docids in v2docs.items():
            try:
                input = float(input)
                value = float(value)
            except Exception:
                input = str(input)
                value = str(value)

            for conds in [
                (operator == "contains", str(value).lower() in str(input).lower()),
                (operator == "not contains", str(value).lower() not in str(input).lower()),
                (operator == "start with", str(input).lower().startswith(str(value).lower())),
                (operator == "end with", str(input).lower().endswith(str(value).lower())),
                (operator == "empty", not input),
                (operator == "not empty", input),
                (operator == "=", input == value),
                (operator == "≠", input != value),
                (operator == ">", input > value),
                (operator == "<", input < value),
                (operator == "≥", input >= value),
                (operator == "≤", input <= value),
            ]:
                try:
                    if all(conds):
                        ids.extend(docids)
                        break
                except Exception:
                    pass
        return ids

    for k, v2docs in metas.items():
        for f in filters:
            if k != f["key"]:
                continue
            ids = filter_out(v2docs, f["op"], f["value"])
            if not doc_ids:
                doc_ids = set(ids)
            else:
                doc_ids = doc_ids & set(ids)
            if not doc_ids:
                return []
    return list(doc_ids)
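
# Illustrative example for meta_filter (hypothetical metadata): metas maps
# field -> value -> document ids, e.g.
#   metas = {"year": {"2023": ["d1"], "2024": ["d2", "d3"]}}
#   meta_filter(metas, [{"key": "year", "op": ">", "value": "2023"}])  # matches d2 and d3
# Successive filters intersect the matched document ids.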


def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    langfuse_tracer = None
    trace_context = {}
    langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
    if langfuse_keys:
        langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
        if langfuse.auth_check():
            langfuse_tracer = langfuse
            trace_id = langfuse_tracer.create_trace_id()
            trace_context = {"trace_id": trace_id}

    check_langfuse_tracer_ts = timer()

    kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
    if toolcall_session and tools:
        chat_mdl.bind_tools(toolcall_session, tools)
    bind_models_ts = timer()

    retriever = settings.retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    # try to use SQL if the field mapping is good to go
    if field_map:
        logging.debug("Use SQL for retrieval: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return
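
    # Every prompt variable declared in prompt_config["parameters"] must be
    # supplied via kwargs unless marked optional; unsupplied optional ones are
    # blanked out of the system prompt.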
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Missing parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    if prompt_config.get("cross_languages"):
        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

    if dialog.meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
        if dialog.meta_data_filter.get("method") == "auto":
            filters = gen_meta_filter(chat_mdl, metas, questions[-1])
            attachments.extend(meta_filter(metas, filters))
            if not attachments:
                attachments = None
        elif dialog.meta_data_filter.get("method") == "manual":
            attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
            if not attachments:
                attachments = None

    if prompt_config.get("keyword", False):
        questions[-1] += keyword_extraction(chat_mdl, questions[-1])

    refine_question_ts = timer()

    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
    knowledges = []

    if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(
                chat_mdl,
                prompt_config,
                partial(
                    retriever.retrieval,
                    embd_mdl=embd_mdl,
                    tenant_ids=tenant_ids,
                    kb_ids=dialog.kb_ids,
                    page=1,
                    page_size=dialog.top_n,
                    similarity_threshold=0.2,
                    vector_similarity_weight=0.3,
                    doc_ids=attachments,
                ),
            )

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            if embd_mdl:
                kbinfos = retriever.retrieval(
                    " ".join(questions),
                    embd_mdl,
                    tenant_ids,
                    dialog.kb_ids,
                    1,
                    dialog.top_n,
                    dialog.similarity_threshold,
                    dialog.vector_similarity_weight,
                    doc_ids=attachments,
                    top=dialog.top_k,
                    aggs=False,
                    rerank_mdl=rerank_mdl,
                    rank_feature=label_question(" ".join(questions), kbs),
                )
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

            knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
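
    # Assemble the final message list: system prompt with variables substituted,
    # optional citation instructions, and history with old citation markers
    # stripped; message_fit_in then trims it to ~95% of the model window and the
    # generation budget is clamped to the tokens that remain.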
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

    def decorate_answer(answer):
        nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer

        refs = []
        ans = answer.split("</think>")
        think = ""
        if len(ans) == 2:
            think = ans[0] + "</think>"
            answer = ans[1]

        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            idx = set([])
            if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer):
                answer, idx = retriever.insert_citations(
                    answer,
                    [ck["content_ltks"] for ck in kbinfos["chunks"]],
                    [ck["vector"] for ck in kbinfos["chunks"]],
                    embd_mdl,
                    tkweight=1 - dialog.vector_similarity_weight,
                    vtweight=dialog.vector_similarity_weight,
                )
            else:
                for match in re.finditer(r"\[ID:([0-9]+)\]", answer):
                    i = int(match.group(1))
                    if i < len(kbinfos["chunks"]):
                        idx.add(i)

            answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)

            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"

        finish_chat_ts = timer()

        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
        check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000
        bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000
        refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000
        retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000
        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000

        tk_num = num_tokens_from_string(think + answer)
        prompt += "\n\n### Query:\n%s" % " ".join(questions)
        prompt = (
            f"{prompt}\n\n"
            "## Time elapsed:\n"
            f"  - Total: {total_time_cost:.1f}ms\n"
            f"  - Check LLM: {check_llm_time_cost:.1f}ms\n"
            f"  - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n"
            f"  - Bind models: {bind_embedding_time_cost:.1f}ms\n"
            f"  - Query refinement (LLM): {refine_question_time_cost:.1f}ms\n"
            f"  - Retrieval: {retrieval_time_cost:.1f}ms\n"
            f"  - Generate answer: {generate_result_time_cost:.1f}ms\n\n"
            "## Token usage:\n"
            f"  - Generated tokens (approximately): {tk_num}\n"
            f"  - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s"
        )

        # Call the end method only if the langfuse_tracer exists
        if langfuse_tracer and "langfuse_generation" in locals():
            langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL)
            langfuse_output = {"time_elapsed:": re.sub(r"\n", "  \n", langfuse_output), "created_at": time.time()}
            langfuse_generation.update(output=langfuse_output)
            langfuse_generation.end()

        return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", "  \n", prompt), "created_at": time.time()}

    if langfuse_tracer:
        langfuse_generation = langfuse_tracer.start_generation(
            trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}
        )

    if stream:
        last_ans = ""
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
            if thought:
                ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
            answer = ans
            delta_ans = ans[len(last_ans):]
            if num_tokens_from_string(delta_ans) < 16:
                continue
            last_ans = answer
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        delta_ans = answer[len(last_ans):]
        if delta_ans:
            yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
        yield decorate_answer(thought + answer)
    else:
        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
        user_content = msg[-1].get("content", "[content not available]")
        logging.debug("User: {}|Assistant: {}".format(user_content, answer))
        res = decorate_answer(answer)
        res["audio_binary"] = tts(tts_mdl, answer)
        yield res
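

# use_sql is the text-to-SQL path taken when the knowledge base exposes a field
# map (e.g. parsed resumes): the LLM drafts a SQL statement, the retriever runs
# it, and on a database error the prompt is retried once with the failing SQL
# and the error message appended. The result is rendered as a Markdown table.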
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
    user_prompt = """
Table name: {};
The table's fields are as follows:
{}

Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)

    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
        sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;；]|```).*", "", sql)
        if sql[: len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[: len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
Table name: {};
The table's fields are as follows:
{}

Questions are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.

The SQL you provided last time is as follows:
{}

The error issued by the database is as follows:
{}

Please correct the error and write the SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"])
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl is None or tbl.get("error") or len(tbl["rows"]) == 0:  # the retry may also fail outright
        return None

    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
    doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
    column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]

    # compose a Markdown table
    columns = (
        "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and doc_name_idx else "|")
    )
    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and doc_name_idx else "")
    rows = ["|" + "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        # append per-row citation markers so the answer can be quoted
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)  # strip time-of-day from datetime cells

    if not docid_idx or not doc_name_idx:
        logging.warning("SQL missing field: " + sql)
        return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt}

    docid_idx = list(docid_idx)[0]
    doc_name_idx = list(doc_name_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([columns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
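

# tts synthesizes speech for a piece of text and returns it hex-encoded, so the
# audio can travel inside the JSON-serializable payloads yielded by chat().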
def tts(tts_mdl, text):
    if not tts_mdl or not text:
        return
    bin = b""
    for chunk in tts_mdl.tts(text):
        bin += chunk
    return binascii.hexlify(bin).decode("utf-8")
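

# ask is a lightweight single-turn flow: retrieve, render the ASK_SUMMARY
# prompt, stream the answer, then decorate it with citations and references.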
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
    doc_ids = search_config.get("doc_ids", [])
    rerank_mdl = None
    kb_ids = search_config.get("kb_ids", kb_ids)
    chat_llm_name = search_config.get("chat_id", chat_llm_name)
    rerank_id = search_config.get("rerank_id", "")
    meta_data_filter = search_config.get("meta_data_filter")

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list(set([kb.embd_id for kb in kbs]))

    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
    if rerank_id:
        rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
    max_tokens = chat_mdl.max_length
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))

    if meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        if meta_data_filter.get("method") == "auto":
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
            if not doc_ids:
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
            if not doc_ids:
                doc_ids = None

    kbinfos = retriever.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
        page=1,
        page_size=12,
        similarity_threshold=search_config.get("similarity_threshold", 0.1),
        vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
        top=search_config.get("top_k", 1024),
        doc_ids=doc_ids,
        aggs=False,
        rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs),
    )
    knowledges = kb_prompt(kbinfos, max_tokens)
    sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges))
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, sys_prompt
        answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        refs["chunks"] = chunks_format(refs)
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)
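

# gen_mindmap retrieves chunks relevant to the question and asks the LLM to
# organize them into a mind map; MindMapExtractor is async, so it runs via trio.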
def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
    meta_data_filter = search_config.get("meta_data_filter", {})
    doc_ids = search_config.get("doc_ids", [])
    rerank_id = search_config.get("rerank_id", "")
    rerank_mdl = None

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    if not kbs:
        return {"error": "No KB selected"}
    embedding_list = list(set([kb.embd_id for kb in kbs]))
    tenant_ids = list(set([kb.tenant_id for kb in kbs]))

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
    if rerank_id:
        rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)

    if meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        if meta_data_filter.get("method") == "auto":
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
            if not doc_ids:
                doc_ids = None
        elif meta_data_filter.get("method") == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
            if not doc_ids:
                doc_ids = None

    ranks = settings.retrievaler.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
        page=1,
        page_size=12,
        similarity_threshold=search_config.get("similarity_threshold", 0.2),
        vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
        top=search_config.get("top_k", 1024),
        doc_ids=doc_ids,
        aggs=False,
        rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs),
    )

    mindmap = MindMapExtractor(chat_mdl)
    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
    return mind_map.output
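

# Usage sketch (illustrative, not part of this module). Assumes
# DialogService.get_by_id follows CommonService's (found, obj) convention:
#
#   found, dialog = DialogService.get_by_id("<dialog_id>")
#   if found:
#       history = [{"role": "user", "content": "Summarize the onboarding doc."}]
#       for delta in chat(dialog, history, stream=True):
#           print(delta["answer"])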