You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

chat.py 12KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from flask import request
  17. from api.db import StatusEnum
  18. from api.db.db_models import TenantLLM
  19. from api.db.services.dialog_service import DialogService
  20. from api.db.services.knowledgebase_service import KnowledgebaseService
  21. from api.db.services.llm_service import LLMService, TenantLLMService
  22. from api.db.services.user_service import TenantService
  23. from api.settings import RetCode
  24. from api.utils import get_uuid
  25. from api.utils.api_utils import get_error_data_result, token_required
  26. from api.utils.api_utils import get_result
  27. @manager.route('/chat', methods=['POST'])
  28. @token_required
  29. def create(tenant_id):
  30. req=request.json
  31. if not req.get("knowledgebases"):
  32. return get_error_data_result(retmsg="knowledgebases are required")
  33. kb_list = []
  34. for kb in req.get("knowledgebases"):
  35. if not kb["id"]:
  36. return get_error_data_result(retmsg="knowledgebase needs id")
  37. if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
  38. return get_error_data_result(retmsg="you do not own the knowledgebase")
  39. # if not DocumentService.query(kb_id=kb["id"]):
  40. # return get_error_data_result(retmsg="There is a invalid knowledgebase")
  41. kb_list.append(kb["id"])
  42. req["kb_ids"] = kb_list
  43. # llm
  44. llm = req.get("llm")
  45. if llm:
  46. if "model_name" in llm:
  47. req["llm_id"] = llm.pop("model_name")
  48. req["llm_setting"] = req.pop("llm")
  49. e, tenant = TenantService.get_by_id(tenant_id)
  50. if not e:
  51. return get_error_data_result(retmsg="Tenant not found!")
  52. # prompt
  53. prompt = req.get("prompt")
  54. key_mapping = {"parameters": "variables",
  55. "prologue": "opener",
  56. "quote": "show_quote",
  57. "system": "prompt",
  58. "rerank_id": "rerank_model",
  59. "vector_similarity_weight": "keywords_similarity_weight"}
  60. key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
  61. if prompt:
  62. for new_key, old_key in key_mapping.items():
  63. if old_key in prompt:
  64. prompt[new_key] = prompt.pop(old_key)
  65. for key in key_list:
  66. if key in prompt:
  67. req[key] = prompt.pop(key)
  68. req["prompt_config"] = req.pop("prompt")
  69. # init
  70. req["id"] = get_uuid()
  71. req["description"] = req.get("description", "A helpful Assistant")
  72. req["icon"] = req.get("avatar", "")
  73. req["top_n"] = req.get("top_n", 6)
  74. req["top_k"] = req.get("top_k", 1024)
  75. req["rerank_id"] = req.get("rerank_id", "")
  76. if req.get("llm_id"):
  77. if not TenantLLMService.query(llm_name=req["llm_id"]):
  78. return get_error_data_result(retmsg="the model_name does not exist.")
  79. else:
  80. req["llm_id"] = tenant.llm_id
  81. if not req.get("name"):
  82. return get_error_data_result(retmsg="name is required.")
  83. if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
  84. return get_error_data_result(retmsg="Duplicated chat name in creating dataset.")
  85. # tenant_id
  86. if req.get("tenant_id"):
  87. return get_error_data_result(retmsg="tenant_id must not be provided.")
  88. req["tenant_id"] = tenant_id
  89. # prompt more parameter
  90. default_prompt = {
  91. "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
  92. 以下是知识库:
  93. {knowledge}
  94. 以上是知识库。""",
  95. "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
  96. "parameters": [
  97. {"key": "knowledge", "optional": False}
  98. ],
  99. "empty_response": "Sorry! 知识库中未找到相关内容!"
  100. }
  101. key_list_2 = ["system", "prologue", "parameters", "empty_response"]
  102. if "prompt_config" not in req:
  103. req['prompt_config'] = {}
  104. for key in key_list_2:
  105. temp = req['prompt_config'].get(key)
  106. if not temp:
  107. req['prompt_config'][key] = default_prompt[key]
  108. for p in req['prompt_config']["parameters"]:
  109. if p["optional"]:
  110. continue
  111. if req['prompt_config']["system"].find("{%s}" % p["key"]) < 0:
  112. return get_error_data_result(
  113. retmsg="Parameter '{}' is not used".format(p["key"]))
  114. # save
  115. if not DialogService.save(**req):
  116. return get_error_data_result(retmsg="Fail to new a chat!")
  117. # response
  118. e, res = DialogService.get_by_id(req["id"])
  119. if not e:
  120. return get_error_data_result(retmsg="Fail to new a chat!")
  121. res = res.to_json()
  122. renamed_dict = {}
  123. for key, value in res["prompt_config"].items():
  124. new_key = key_mapping.get(key, key)
  125. renamed_dict[new_key] = value
  126. res["prompt"] = renamed_dict
  127. del res["prompt_config"]
  128. new_dict = {"similarity_threshold": res["similarity_threshold"],
  129. "keywords_similarity_weight": res["vector_similarity_weight"],
  130. "top_n": res["top_n"],
  131. "rerank_model": res['rerank_id']}
  132. res["prompt"].update(new_dict)
  133. for key in key_list:
  134. del res[key]
  135. res["llm"] = res.pop("llm_setting")
  136. res["llm"]["model_name"] = res.pop("llm_id")
  137. del res["kb_ids"]
  138. res["knowledgebases"] = req["knowledgebases"]
  139. res["avatar"] = res.pop("icon")
  140. return get_result(data=res)
  141. @manager.route('/chat/<chat_id>', methods=['PUT'])
  142. @token_required
  143. def update(tenant_id,chat_id):
  144. if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
  145. return get_error_data_result(retmsg='You do not own the chat')
  146. req =request.json
  147. if "knowledgebases" in req:
  148. if not req.get("knowledgebases"):
  149. return get_error_data_result(retmsg="knowledgebases can't be empty value")
  150. kb_list = []
  151. for kb in req.get("knowledgebases"):
  152. if not kb["id"]:
  153. return get_error_data_result(retmsg="knowledgebase needs id")
  154. if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
  155. return get_error_data_result(retmsg="you do not own the knowledgebase")
  156. # if not DocumentService.query(kb_id=kb["id"]):
  157. # return get_error_data_result(retmsg="There is a invalid knowledgebase")
  158. kb_list.append(kb["id"])
  159. req["kb_ids"] = kb_list
  160. llm = req.get("llm")
  161. if llm:
  162. if "model_name" in llm:
  163. req["llm_id"] = llm.pop("model_name")
  164. req["llm_setting"] = req.pop("llm")
  165. e, tenant = TenantService.get_by_id(tenant_id)
  166. if not e:
  167. return get_error_data_result(retmsg="Tenant not found!")
  168. # prompt
  169. prompt = req.get("prompt")
  170. key_mapping = {"parameters": "variables",
  171. "prologue": "opener",
  172. "quote": "show_quote",
  173. "system": "prompt",
  174. "rerank_id": "rerank_model",
  175. "vector_similarity_weight": "keywords_similarity_weight"}
  176. key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
  177. if prompt:
  178. for new_key, old_key in key_mapping.items():
  179. if old_key in prompt:
  180. prompt[new_key] = prompt.pop(old_key)
  181. for key in key_list:
  182. if key in prompt:
  183. req[key] = prompt.pop(key)
  184. req["prompt_config"] = req.pop("prompt")
  185. e, res = DialogService.get_by_id(chat_id)
  186. res = res.to_json()
  187. if "llm_id" in req:
  188. if not TenantLLMService.query(llm_name=req["llm_id"]):
  189. return get_error_data_result(retmsg="the model_name does not exist.")
  190. if "name" in req:
  191. if not req.get("name"):
  192. return get_error_data_result(retmsg="name is not empty.")
  193. if req["name"].lower() != res["name"].lower() \
  194. and len(
  195. DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0:
  196. return get_error_data_result(retmsg="Duplicated chat name in updating dataset.")
  197. if "prompt_config" in req:
  198. res["prompt_config"].update(req["prompt_config"])
  199. for p in res["prompt_config"]["parameters"]:
  200. if p["optional"]:
  201. continue
  202. if res["prompt_config"]["system"].find("{%s}" % p["key"]) < 0:
  203. return get_error_data_result(retmsg="Parameter '{}' is not used".format(p["key"]))
  204. if "llm_setting" in req:
  205. res["llm_setting"].update(req["llm_setting"])
  206. req["prompt_config"] = res["prompt_config"]
  207. req["llm_setting"] = res["llm_setting"]
  208. # avatar
  209. if "avatar" in req:
  210. req["icon"] = req.pop("avatar")
  211. if "knowledgebases" in req:
  212. req.pop("knowledgebases")
  213. if not DialogService.update_by_id(chat_id, req):
  214. return get_error_data_result(retmsg="Chat not found!")
  215. return get_result()
  216. @manager.route('/chat', methods=['DELETE'])
  217. @token_required
  218. def delete(tenant_id):
  219. req = request.json
  220. ids = req.get("ids")
  221. if not ids:
  222. return get_error_data_result(retmsg="ids are required")
  223. for id in ids:
  224. if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value):
  225. return get_error_data_result(retmsg=f"You don't own the chat {id}")
  226. temp_dict = {"status": StatusEnum.INVALID.value}
  227. DialogService.update_by_id(id, temp_dict)
  228. return get_result()
  229. @manager.route('/chat', methods=['GET'])
  230. @token_required
  231. def list(tenant_id):
  232. id = request.args.get("id")
  233. name = request.args.get("name")
  234. chat = DialogService.query(id=id,name=name,status=StatusEnum.VALID.value)
  235. if not chat:
  236. return get_error_data_result(retmsg="The chat doesn't exist")
  237. page_number = int(request.args.get("page", 1))
  238. items_per_page = int(request.args.get("page_size", 1024))
  239. orderby = request.args.get("orderby", "create_time")
  240. if request.args.get("desc") == "False" or request.args.get("desc") == "false":
  241. desc = False
  242. else:
  243. desc = True
  244. chats = DialogService.get_list(tenant_id,page_number,items_per_page,orderby,desc,id,name)
  245. if not chats:
  246. return get_result(data=[])
  247. list_assts = []
  248. renamed_dict = {}
  249. key_mapping = {"parameters": "variables",
  250. "prologue": "opener",
  251. "quote": "show_quote",
  252. "system": "prompt",
  253. "rerank_id": "rerank_model",
  254. "vector_similarity_weight": "keywords_similarity_weight"}
  255. key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
  256. for res in chats:
  257. for key, value in res["prompt_config"].items():
  258. new_key = key_mapping.get(key, key)
  259. renamed_dict[new_key] = value
  260. res["prompt"] = renamed_dict
  261. del res["prompt_config"]
  262. new_dict = {"similarity_threshold": res["similarity_threshold"],
  263. "keywords_similarity_weight": res["vector_similarity_weight"],
  264. "top_n": res["top_n"],
  265. "rerank_model": res['rerank_id']}
  266. res["prompt"].update(new_dict)
  267. for key in key_list:
  268. del res[key]
  269. res["llm"] = res.pop("llm_setting")
  270. res["llm"]["model_name"] = res.pop("llm_id")
  271. kb_list = []
  272. for kb_id in res["kb_ids"]:
  273. kb = KnowledgebaseService.query(id=kb_id)
  274. if not kb :
  275. return get_error_data_result(retmsg=f"Don't exist the kb {kb_id}")
  276. kb_list.append(kb[0].to_json())
  277. del res["kb_ids"]
  278. res["knowledgebases"] = kb_list
  279. res["avatar"] = res.pop("icon")
  280. list_assts.append(res)
  281. return get_result(data=list_assts)