Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

chat.py 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. from flask import request
  18. from api import settings
  19. from api.db import StatusEnum
  20. from api.db.services.dialog_service import DialogService
  21. from api.db.services.knowledgebase_service import KnowledgebaseService
  22. from api.db.services.llm_service import TenantLLMService
  23. from api.db.services.user_service import TenantService
  24. from api.utils import get_uuid
  25. from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
  26. @manager.route("/chats", methods=["POST"]) # noqa: F821
  27. @token_required
  28. def create(tenant_id):
  29. req = request.json
  30. ids = [i for i in req.get("dataset_ids", []) if i]
  31. for kb_id in ids:
  32. kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
  33. if not kbs:
  34. return get_error_data_result(f"You don't own the dataset {kb_id}")
  35. kbs = KnowledgebaseService.query(id=kb_id)
  36. kb = kbs[0]
  37. if kb.chunk_num == 0:
  38. return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
  39. kbs = KnowledgebaseService.get_by_ids(ids) if ids else []
  40. embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
  41. embd_count = list(set(embd_ids))
  42. if len(embd_count) > 1:
  43. return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
  44. req["kb_ids"] = ids
  45. # llm
  46. llm = req.get("llm")
  47. if llm:
  48. if "model_name" in llm:
  49. req["llm_id"] = llm.pop("model_name")
  50. if req.get("llm_id") is not None:
  51. llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"])
  52. if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"):
  53. return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
  54. req["llm_setting"] = req.pop("llm")
  55. e, tenant = TenantService.get_by_id(tenant_id)
  56. if not e:
  57. return get_error_data_result(message="Tenant not found!")
  58. # prompt
  59. prompt = req.get("prompt")
  60. key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
  61. key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
  62. if prompt:
  63. for new_key, old_key in key_mapping.items():
  64. if old_key in prompt:
  65. prompt[new_key] = prompt.pop(old_key)
  66. for key in key_list:
  67. if key in prompt:
  68. req[key] = prompt.pop(key)
  69. req["prompt_config"] = req.pop("prompt")
  70. # init
  71. req["id"] = get_uuid()
  72. req["description"] = req.get("description", "A helpful Assistant")
  73. req["icon"] = req.get("avatar", "")
  74. req["top_n"] = req.get("top_n", 6)
  75. req["top_k"] = req.get("top_k", 1024)
  76. req["rerank_id"] = req.get("rerank_id", "")
  77. if req.get("rerank_id"):
  78. value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
  79. if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
  80. return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
  81. if not req.get("llm_id"):
  82. req["llm_id"] = tenant.llm_id
  83. if not req.get("name"):
  84. return get_error_data_result(message="`name` is required.")
  85. if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
  86. return get_error_data_result(message="Duplicated chat name in creating chat.")
  87. # tenant_id
  88. if req.get("tenant_id"):
  89. return get_error_data_result(message="`tenant_id` must not be provided.")
  90. req["tenant_id"] = tenant_id
  91. # prompt more parameter
  92. default_prompt = {
  93. "system": """You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.
  94. Here is the knowledge base:
  95. {knowledge}
  96. The above is the knowledge base.""",
  97. "prologue": "Hi! I'm your assistant, what can I do for you?",
  98. "parameters": [{"key": "knowledge", "optional": False}],
  99. "empty_response": "Sorry! No relevant content was found in the knowledge base!",
  100. "quote": True,
  101. "tts": False,
  102. "refine_multiturn": True,
  103. }
  104. key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"]
  105. if "prompt_config" not in req:
  106. req["prompt_config"] = {}
  107. for key in key_list_2:
  108. temp = req["prompt_config"].get(key)
  109. if (not temp and key == "system") or (key not in req["prompt_config"]):
  110. req["prompt_config"][key] = default_prompt[key]
  111. for p in req["prompt_config"]["parameters"]:
  112. if p["optional"]:
  113. continue
  114. if req["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
  115. return get_error_data_result(message="Parameter '{}' is not used".format(p["key"]))
  116. # save
  117. if not DialogService.save(**req):
  118. return get_error_data_result(message="Fail to new a chat!")
  119. # response
  120. e, res = DialogService.get_by_id(req["id"])
  121. if not e:
  122. return get_error_data_result(message="Fail to new a chat!")
  123. res = res.to_json()
  124. renamed_dict = {}
  125. for key, value in res["prompt_config"].items():
  126. new_key = key_mapping.get(key, key)
  127. renamed_dict[new_key] = value
  128. res["prompt"] = renamed_dict
  129. del res["prompt_config"]
  130. new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
  131. res["prompt"].update(new_dict)
  132. for key in key_list:
  133. del res[key]
  134. res["llm"] = res.pop("llm_setting")
  135. res["llm"]["model_name"] = res.pop("llm_id")
  136. del res["kb_ids"]
  137. res["dataset_ids"] = req["dataset_ids"]
  138. res["avatar"] = res.pop("icon")
  139. return get_result(data=res)
  140. @manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
  141. @token_required
  142. def update(tenant_id, chat_id):
  143. if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
  144. return get_error_data_result(message="You do not own the chat")
  145. req = request.json
  146. ids = req.get("dataset_ids")
  147. if "show_quotation" in req:
  148. req["do_refer"] = req.pop("show_quotation")
  149. if ids is not None:
  150. for kb_id in ids:
  151. kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
  152. if not kbs:
  153. return get_error_data_result(f"You don't own the dataset {kb_id}")
  154. kbs = KnowledgebaseService.query(id=kb_id)
  155. kb = kbs[0]
  156. if kb.chunk_num == 0:
  157. return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
  158. kbs = KnowledgebaseService.get_by_ids(ids)
  159. embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
  160. embd_count = list(set(embd_ids))
  161. if len(embd_count) != 1:
  162. return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR)
  163. req["kb_ids"] = ids
  164. llm = req.get("llm")
  165. if llm:
  166. if "model_name" in llm:
  167. req["llm_id"] = llm.pop("model_name")
  168. if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"):
  169. return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
  170. req["llm_setting"] = req.pop("llm")
  171. e, tenant = TenantService.get_by_id(tenant_id)
  172. if not e:
  173. return get_error_data_result(message="Tenant not found!")
  174. # prompt
  175. prompt = req.get("prompt")
  176. key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"}
  177. key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"]
  178. if prompt:
  179. for new_key, old_key in key_mapping.items():
  180. if old_key in prompt:
  181. prompt[new_key] = prompt.pop(old_key)
  182. for key in key_list:
  183. if key in prompt:
  184. req[key] = prompt.pop(key)
  185. req["prompt_config"] = req.pop("prompt")
  186. e, res = DialogService.get_by_id(chat_id)
  187. res = res.to_json()
  188. if req.get("rerank_id"):
  189. value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"]
  190. if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"):
  191. return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
  192. if "name" in req:
  193. if not req.get("name"):
  194. return get_error_data_result(message="`name` cannot be empty.")
  195. if req["name"].lower() != res["name"].lower() and len(DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
  196. return get_error_data_result(message="Duplicated chat name in updating chat.")
  197. if "prompt_config" in req:
  198. res["prompt_config"].update(req["prompt_config"])
  199. for p in res["prompt_config"]["parameters"]:
  200. if p["optional"]:
  201. continue
  202. if res["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
  203. return get_error_data_result(message="Parameter '{}' is not used".format(p["key"]))
  204. if "llm_setting" in req:
  205. res["llm_setting"].update(req["llm_setting"])
  206. req["prompt_config"] = res["prompt_config"]
  207. req["llm_setting"] = res["llm_setting"]
  208. # avatar
  209. if "avatar" in req:
  210. req["icon"] = req.pop("avatar")
  211. if "dataset_ids" in req:
  212. req.pop("dataset_ids")
  213. if not DialogService.update_by_id(chat_id, req):
  214. return get_error_data_result(message="Chat not found!")
  215. return get_result()
  216. @manager.route("/chats", methods=["DELETE"]) # noqa: F821
  217. @token_required
  218. def delete(tenant_id):
  219. errors = []
  220. success_count = 0
  221. req = request.json
  222. if not req:
  223. ids = None
  224. else:
  225. ids = req.get("ids")
  226. if not ids:
  227. id_list = []
  228. dias = DialogService.query(tenant_id=tenant_id, status=StatusEnum.VALID.value)
  229. for dia in dias:
  230. id_list.append(dia.id)
  231. else:
  232. id_list = ids
  233. unique_id_list, duplicate_messages = check_duplicate_ids(id_list, "assistant")
  234. for id in unique_id_list:
  235. if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value):
  236. errors.append(f"Assistant({id}) not found.")
  237. continue
  238. temp_dict = {"status": StatusEnum.INVALID.value}
  239. DialogService.update_by_id(id, temp_dict)
  240. success_count += 1
  241. if errors:
  242. if success_count > 0:
  243. return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} chats with {len(errors)} errors")
  244. else:
  245. return get_error_data_result(message="; ".join(errors))
  246. if duplicate_messages:
  247. if success_count > 0:
  248. return get_result(message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
  249. else:
  250. return get_error_data_result(message=";".join(duplicate_messages))
  251. return get_result()
  252. @manager.route("/chats", methods=["GET"]) # noqa: F821
  253. @token_required
  254. def list_chat(tenant_id):
  255. id = request.args.get("id")
  256. name = request.args.get("name")
  257. if id or name:
  258. chat = DialogService.query(id=id, name=name, status=StatusEnum.VALID.value, tenant_id=tenant_id)
  259. if not chat:
  260. return get_error_data_result(message="The chat doesn't exist")
  261. page_number = int(request.args.get("page", 1))
  262. items_per_page = int(request.args.get("page_size", 30))
  263. orderby = request.args.get("orderby", "create_time")
  264. if request.args.get("desc") == "False" or request.args.get("desc") == "false":
  265. desc = False
  266. else:
  267. desc = True
  268. chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name)
  269. if not chats:
  270. return get_result(data=[])
  271. list_assts = []
  272. key_mapping = {
  273. "parameters": "variables",
  274. "prologue": "opener",
  275. "quote": "show_quote",
  276. "system": "prompt",
  277. "rerank_id": "rerank_model",
  278. "vector_similarity_weight": "keywords_similarity_weight",
  279. "do_refer": "show_quotation",
  280. }
  281. key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
  282. for res in chats:
  283. renamed_dict = {}
  284. for key, value in res["prompt_config"].items():
  285. new_key = key_mapping.get(key, key)
  286. renamed_dict[new_key] = value
  287. res["prompt"] = renamed_dict
  288. del res["prompt_config"]
  289. new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]}
  290. res["prompt"].update(new_dict)
  291. for key in key_list:
  292. del res[key]
  293. res["llm"] = res.pop("llm_setting")
  294. res["llm"]["model_name"] = res.pop("llm_id")
  295. kb_list = []
  296. for kb_id in res["kb_ids"]:
  297. kb = KnowledgebaseService.query(id=kb_id)
  298. if not kb:
  299. logging.warning(f"The kb {kb_id} does not exist.")
  300. continue
  301. kb_list.append(kb[0].to_json())
  302. del res["kb_ids"]
  303. res["datasets"] = kb_list
  304. res["avatar"] = res.pop("icon")
  305. list_assts.append(res)
  306. return get_result(data=list_assts)