Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

conversation_app.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import re
  18. import traceback
  19. from copy import deepcopy
  20. import trio
  21. from flask import Response, request
  22. from flask_login import current_user, login_required
  23. from api import settings
  24. from api.db import LLMType
  25. from api.db.db_models import APIToken
  26. from api.db.services.conversation_service import ConversationService, structure_answer
  27. from api.db.services.dialog_service import DialogService, ask, chat
  28. from api.db.services.knowledgebase_service import KnowledgebaseService
  29. from api.db.services.llm_service import LLMBundle
  30. from api.db.services.tenant_llm_service import TenantLLMService
  31. from api.db.services.user_service import TenantService, UserTenantService
  32. from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
  33. from graphrag.general.mind_map_extractor import MindMapExtractor
  34. from rag.app.tag import label_question
  35. from rag.prompts.prompts import chunks_format
  36. @manager.route("/set", methods=["POST"]) # noqa: F821
  37. @login_required
  38. def set_conversation():
  39. req = request.json
  40. conv_id = req.get("conversation_id")
  41. is_new = req.get("is_new")
  42. name = req.get("name", "New conversation")
  43. req["user_id"] = current_user.id
  44. if len(name) > 255:
  45. name = name[0:255]
  46. del req["is_new"]
  47. if not is_new:
  48. del req["conversation_id"]
  49. try:
  50. if not ConversationService.update_by_id(conv_id, req):
  51. return get_data_error_result(message="Conversation not found!")
  52. e, conv = ConversationService.get_by_id(conv_id)
  53. if not e:
  54. return get_data_error_result(message="Fail to update a conversation!")
  55. conv = conv.to_dict()
  56. return get_json_result(data=conv)
  57. except Exception as e:
  58. return server_error_response(e)
  59. try:
  60. e, dia = DialogService.get_by_id(req["dialog_id"])
  61. if not e:
  62. return get_data_error_result(message="Dialog not found")
  63. conv = {
  64. "id": conv_id,
  65. "dialog_id": req["dialog_id"],
  66. "name": name,
  67. "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
  68. "user_id": current_user.id,
  69. "reference": [],
  70. }
  71. ConversationService.save(**conv)
  72. return get_json_result(data=conv)
  73. except Exception as e:
  74. return server_error_response(e)
@manager.route("/get", methods=["GET"])  # noqa: F821
@login_required
def get():
    """Fetch a conversation by id, enforcing that the caller owns the parent dialog."""
    conv_id = request.args["conversation_id"]
    try:
        e, conv = ConversationService.get_by_id(conv_id)
        if not e:
            return get_data_error_result(message="Conversation not found!")

        # Authorization: the conversation's dialog must belong to one of the
        # current user's tenants. The for/else fires only when no tenant
        # matched (loop never hit `break`).
        tenants = UserTenantService.query(user_id=current_user.id)
        avatar = None
        for tenant in tenants:
            dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
            if dialog and len(dialog) > 0:
                avatar = dialog[0].icon
                break
        else:
            return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)

        # Normalize stored reference entries into the client chunk format;
        # list-shaped entries are left untouched.
        for ref in conv.reference:
            if isinstance(ref, list):
                continue
            ref["chunks"] = chunks_format(ref)

        conv = conv.to_dict()
        conv["avatar"] = avatar
        return get_json_result(data=conv)
    except Exception as e:
        return server_error_response(e)
  101. @manager.route("/getsse/<dialog_id>", methods=["GET"]) # type: ignore # noqa: F821
  102. def getsse(dialog_id):
  103. token = request.headers.get("Authorization").split()
  104. if len(token) != 2:
  105. return get_data_error_result(message='Authorization is not valid!"')
  106. token = token[1]
  107. objs = APIToken.query(beta=token)
  108. if not objs:
  109. return get_data_error_result(message='Authentication error: API key is invalid!"')
  110. try:
  111. e, conv = DialogService.get_by_id(dialog_id)
  112. if not e:
  113. return get_data_error_result(message="Dialog not found!")
  114. conv = conv.to_dict()
  115. conv["avatar"] = conv["icon"]
  116. del conv["icon"]
  117. return get_json_result(data=conv)
  118. except Exception as e:
  119. return server_error_response(e)
@manager.route("/rm", methods=["POST"])  # noqa: F821
@login_required
def rm():
    """Delete a batch of conversations after verifying the caller owns each one."""
    conv_ids = request.json["conversation_ids"]
    try:
        for cid in conv_ids:
            exist, conv = ConversationService.get_by_id(cid)
            if not exist:
                return get_data_error_result(message="Conversation not found!")
            # Ownership check per conversation: the for/else fires only when
            # none of the user's tenants owns the parent dialog.
            tenants = UserTenantService.query(user_id=current_user.id)
            for tenant in tenants:
                if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
                    break
            else:
                return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
            ConversationService.delete_by_id(cid)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
  139. @manager.route("/list", methods=["GET"]) # noqa: F821
  140. @login_required
  141. def list_conversation():
  142. dialog_id = request.args["dialog_id"]
  143. try:
  144. if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
  145. return get_json_result(data=False, message="Only owner of dialog authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
  146. convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
  147. convs = [d.to_dict() for d in convs]
  148. return get_json_result(data=convs)
  149. except Exception as e:
  150. return server_error_response(e)
@manager.route("/completion", methods=["POST"])  # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
    """Run one chat turn for a conversation, streaming (SSE) or blocking.

    Filters the incoming message list, optionally overrides the dialog's LLM
    with "llm_id" plus per-request generation settings, then delegates to
    dialog_service.chat().
    """
    req = request.json
    # Keep only the real exchange: drop system messages and any assistant
    # message that would precede the first user message.
    msg = []
    for m in req["messages"]:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append(m)
    message_id = msg[-1].get("id")
    chat_model_id = req.get("llm_id", "")
    req.pop("llm_id", None)

    # Collect per-request generation parameters for the override model.
    chat_model_config = {}
    for model_config in [
        "temperature",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "max_tokens",
    ]:
        config = req.get(model_config)
        if config:
            chat_model_config[model_config] = config

    try:
        e, conv = ConversationService.get_by_id(req["conversation_id"])
        if not e:
            return get_data_error_result(message="Conversation not found!")
        conv.message = deepcopy(req["messages"])
        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        # Remaining keys of `req` are forwarded to chat() as kwargs.
        del req["conversation_id"]
        del req["messages"]

        if not conv.reference:
            conv.reference = []
        # Drop empty reference slots, then reserve one for this turn's answer.
        conv.reference = [r for r in conv.reference if r]
        conv.reference.append({"chunks": [], "doc_aggs": []})

        if chat_model_id:
            # Only allow the override model if the tenant has an API key for it.
            if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
                req.pop("chat_model_id", None)
                req.pop("chat_model_config", None)
                return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
            dia.llm_id = chat_model_id
            dia.llm_setting = chat_model_config

        # With an override model (embedded/API caller) the conversation is
        # not persisted server-side.
        is_embedded = bool(chat_model_id)

        def stream():
            # SSE generator: one "data:" frame per partial answer, then a
            # final `data: true` sentinel frame.
            nonlocal dia, msg, req, conv
            try:
                for ans in chat(dia, msg, True, **req):
                    ans = structure_answer(conv, ans, message_id, conv.id)
                    yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
                if not is_embedded:
                    ConversationService.update_by_id(conv.id, conv.to_dict())
            except Exception as e:
                # Report errors in-band so the SSE client sees them.
                traceback.print_exc()
                yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

        if req.get("stream", True):
            resp = Response(stream(), mimetype="text/event-stream")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp
        else:
            # Blocking mode: take the first (complete) answer only.
            answer = None
            for ans in chat(dia, msg, **req):
                answer = structure_answer(conv, ans, message_id, conv.id)
                if not is_embedded:
                    ConversationService.update_by_id(conv.id, conv.to_dict())
                break
            return get_json_result(data=answer)
    except Exception as e:
        return server_error_response(e)
@manager.route("/tts", methods=["POST"])  # noqa: F821
@login_required
def tts():
    """Synthesize request text with the tenant's default TTS model, streamed as MPEG audio."""
    req = request.json
    text = req["text"]
    tenants = TenantService.get_info_by(current_user.id)
    if not tenants:
        return get_data_error_result(message="Tenant not found!")
    tts_id = tenants[0]["tts_id"]
    if not tts_id:
        return get_data_error_result(message="No default TTS model is set")
    tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)

    def stream_audio():
        # Split on ASCII and CJK sentence punctuation so each fragment is
        # synthesized and streamed separately.
        try:
            for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
                for chunk in tts_mdl.tts(txt):
                    yield chunk
        except Exception as e:
            # Errors are reported in-band, encoded to bytes to match the
            # binary audio stream.
            yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")

    resp = Response(stream_audio(), mimetype="audio/mpeg")
    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    return resp
  252. @manager.route("/delete_msg", methods=["POST"]) # noqa: F821
  253. @login_required
  254. @validate_request("conversation_id", "message_id")
  255. def delete_msg():
  256. req = request.json
  257. e, conv = ConversationService.get_by_id(req["conversation_id"])
  258. if not e:
  259. return get_data_error_result(message="Conversation not found!")
  260. conv = conv.to_dict()
  261. for i, msg in enumerate(conv["message"]):
  262. if req["message_id"] != msg.get("id", ""):
  263. continue
  264. assert conv["message"][i + 1]["id"] == req["message_id"]
  265. conv["message"].pop(i)
  266. conv["message"].pop(i)
  267. conv["reference"].pop(max(0, i // 2 - 1))
  268. break
  269. ConversationService.update_by_id(conv["id"], conv)
  270. return get_json_result(data=conv)
  271. @manager.route("/thumbup", methods=["POST"]) # noqa: F821
  272. @login_required
  273. @validate_request("conversation_id", "message_id")
  274. def thumbup():
  275. req = request.json
  276. e, conv = ConversationService.get_by_id(req["conversation_id"])
  277. if not e:
  278. return get_data_error_result(message="Conversation not found!")
  279. up_down = req.get("thumbup")
  280. feedback = req.get("feedback", "")
  281. conv = conv.to_dict()
  282. for i, msg in enumerate(conv["message"]):
  283. if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
  284. if up_down:
  285. msg["thumbup"] = True
  286. if "feedback" in msg:
  287. del msg["feedback"]
  288. else:
  289. msg["thumbup"] = False
  290. if feedback:
  291. msg["feedback"] = feedback
  292. break
  293. ConversationService.update_by_id(conv["id"], conv)
  294. return get_json_result(data=conv)
@manager.route("/ask", methods=["POST"])  # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def ask_about():
    """Answer a question over the given knowledge bases, streamed as SSE."""
    req = request.json
    uid = current_user.id

    def stream():
        # SSE generator: one "data:" frame per partial answer, errors
        # reported in-band, then a final `data: true` sentinel frame.
        nonlocal req, uid
        try:
            for ans in ask(req["question"], req["kb_ids"], uid):
                yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

    resp = Response(stream(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp
@manager.route("/mindmap", methods=["POST"])  # noqa: F821
@login_required
@validate_request("question", "kb_ids")
def mindmap():
    """Build a mind map from chunks retrieved for a question.

    Embedding uses the first knowledge base's tenant/model; extraction uses
    the current user's default chat model.
    """
    req = request.json
    kb_ids = req["kb_ids"]
    e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
    if not e:
        return get_data_error_result(message="Knowledgebase not found!")
    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
    question = req["question"]
    # Retrieve the top chunks (page 1, size 12, 0.3/0.3 similarity weights),
    # without aggregations.
    ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
    mindmap = MindMapExtractor(chat_mdl)
    # The extractor is async; run it to completion on a trio event loop.
    mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
    mind_map = mind_map.output
    if "error" in mind_map:
        return server_error_response(Exception(mind_map["error"]))
    return get_json_result(data=mind_map)
  334. @manager.route("/related_questions", methods=["POST"]) # noqa: F821
  335. @login_required
  336. @validate_request("question")
  337. def related_questions():
  338. req = request.json
  339. question = req["question"]
  340. chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
  341. prompt = """
  342. Role: You are an AI language model assistant tasked with generating 5-10 related questions based on a user’s original query. These questions should help expand the search query scope and improve search relevance.
  343. Instructions:
  344. Input: You are provided with a user’s question.
  345. Output: Generate 5-10 alternative questions that are related to the original user question. These alternatives should help retrieve a broader range of relevant documents from a vector database.
  346. Context: Focus on rephrasing the original question in different ways, making sure the alternative questions are diverse but still connected to the topic of the original query. Do not create overly obscure, irrelevant, or unrelated questions.
  347. Fallback: If you cannot generate any relevant alternatives, do not return any questions.
  348. Guidance:
  349. 1. Each alternative should be unique but still relevant to the original query.
  350. 2. Keep the phrasing clear, concise, and easy to understand.
  351. 3. Avoid overly technical jargon or specialized terms unless directly relevant.
  352. 4. Ensure that each question contributes towards improving search results by broadening the search angle, not narrowing it.
  353. Example:
  354. Original Question: What are the benefits of electric vehicles?
  355. Alternative Questions:
  356. 1. How do electric vehicles impact the environment?
  357. 2. What are the advantages of owning an electric car?
  358. 3. What is the cost-effectiveness of electric vehicles?
  359. 4. How do electric vehicles compare to traditional cars in terms of fuel efficiency?
  360. 5. What are the environmental benefits of switching to electric cars?
  361. 6. How do electric vehicles help reduce carbon emissions?
  362. 7. Why are electric vehicles becoming more popular?
  363. 8. What are the long-term savings of using electric vehicles?
  364. 9. How do electric vehicles contribute to sustainability?
  365. 10. What are the key benefits of electric vehicles for consumers?
  366. Reason:
  367. Rephrasing the original query into multiple alternative questions helps the user explore different aspects of their search topic, improving the quality of search results.
  368. These questions guide the search engine to provide a more comprehensive set of relevant documents.
  369. """
  370. ans = chat_mdl.chat(
  371. prompt,
  372. [
  373. {
  374. "role": "user",
  375. "content": f"""
  376. Keywords: {question}
  377. Related search terms:
  378. """,
  379. }
  380. ],
  381. {"temperature": 0.9},
  382. )
  383. return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])