Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from flask import request
  17. from api.db import StatusEnum
  18. from api.db.services.dialog_service import DialogService
  19. from api.db.services.knowledgebase_service import KnowledgebaseService
  20. from api.db.services.llm_service import TenantLLMService
  21. from api.db.services.user_service import TenantService
  22. from api.utils import get_uuid
  23. from api.utils.api_utils import get_error_data_result, token_required
  24. from api.utils.api_utils import get_result
  @manager.route('/chat', methods=['POST'])
  @token_required
  def create(tenant_id):
      """Create a chat assistant owned by ``tenant_id``.

      The JSON body uses the public API field names (``knowledgebases``,
      ``llm.model_name``, ``prompt.variables``, ``prompt.opener``, ...). They
      are translated to the internal field names used by ``DialogService``
      before saving, and the saved record is translated back before it is
      returned.

      Args:
          tenant_id: Tenant id injected by ``token_required`` (presumably
              derived from the request's API token).

      Returns:
          ``get_result(data=chat)`` with the created chat on success, or
          ``get_error_data_result(...)`` describing the first failed check.
      """
      req = request.get_json(silent=True)
      # A missing, malformed or non-object body would otherwise fail with an
      # AttributeError on ``req.get`` below.
      if not isinstance(req, dict):
          return get_error_data_result(retmsg="The request body must be a JSON object.")

      # datasets: each one must belong to the tenant and contain parsed chunks.
      ids = req.pop("knowledgebases", None)
      if not ids:
          return get_error_data_result(retmsg="`knowledgebases` is required")
      for kb_id in ids:
          # Reject empty ids up front instead of handing them to the query
          # as a filter value.
          if not kb_id:
              return get_error_data_result(retmsg="knowledgebase needs id")
          kbs = KnowledgebaseService.query(id=kb_id, tenant_id=tenant_id)
          if not kbs:
              return get_error_data_result(f"You don't own the dataset {kb_id}")
          if kbs[0].chunk_num == 0:
              return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
      req["kb_ids"] = ids

      # llm: ``model_name`` becomes ``llm_id``; the remaining keys are stored
      # as ``llm_setting``. Popping "llm" keeps a falsy value out of ``save``.
      llm = req.pop("llm", None)
      if llm:
          if "model_name" in llm:
              req["llm_id"] = llm.pop("model_name")
          req["llm_setting"] = llm
      e, tenant = TenantService.get_by_id(tenant_id)
      if not e:
          return get_error_data_result(retmsg="Tenant not found!")

      # prompt: maps internal field name -> public API name.
      key_mapping = {"parameters": "variables",
                     "prologue": "opener",
                     "quote": "show_quote",
                     "system": "prompt",
                     "rerank_id": "rerank_model",
                     "vector_similarity_weight": "keywords_similarity_weight"}
      # Retrieval settings that are stored as top-level fields rather than
      # inside prompt_config.
      key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
      prompt = req.pop("prompt", None)
      if prompt:
          for new_key, old_key in key_mapping.items():
              if old_key in prompt:
                  prompt[new_key] = prompt.pop(old_key)
          for key in key_list:
              if key in prompt:
                  req[key] = prompt.pop(key)
          req["prompt_config"] = prompt

      # Defaults for the remaining fields.
      req["id"] = get_uuid()
      req["description"] = req.get("description", "A helpful Assistant")
      req["icon"] = req.pop("avatar", "")
      req["top_n"] = req.get("top_n", 6)
      req["top_k"] = req.get("top_k", 1024)
      req["rerank_id"] = req.get("rerank_id", "")
      if req.get("llm_id"):
          # Only models configured for *this* tenant may be referenced.
          if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"]):
              return get_error_data_result(retmsg="the model_name does not exist.")
      else:
          req["llm_id"] = tenant.llm_id
      if not req.get("name"):
          return get_error_data_result(retmsg="`name` is required.")
      if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
          return get_error_data_result(retmsg="Duplicated chat name in creating chat.")
      # The owner always comes from the token, never from the body.
      if req.get("tenant_id"):
          return get_error_data_result(retmsg="`tenant_id` must not be provided.")
      req["tenant_id"] = tenant_id

      # Fill any missing or empty prompt settings from these defaults.
      default_prompt = {
          "system": """You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history.
  Here is the knowledge base:
  {knowledge}
  The above is the knowledge base.""",
          "prologue": "Hi! I'm your assistant, what can I do for you?",
          "parameters": [
              {"key": "knowledge", "optional": False}
          ],
          "empty_response": "Sorry! No relevant content was found in the knowledge base!"
      }
      prompt_config = req.get("prompt_config") or {}
      req["prompt_config"] = prompt_config
      for key in ("system", "prologue", "parameters", "empty_response"):
          if not prompt_config.get(key):
              prompt_config[key] = default_prompt[key]
      # Every required variable must appear as a ``{key}`` placeholder in the
      # system prompt; a variable without "optional" counts as required.
      for p in prompt_config["parameters"]:
          if p.get("optional"):
              continue
          if "{%s}" % p["key"] not in prompt_config["system"]:
              return get_error_data_result(
                  retmsg="Parameter '{}' is not used".format(p["key"]))

      if not DialogService.save(**req):
          return get_error_data_result(retmsg="Fail to new a chat!")

      # Re-read the stored record and convert it back to the public API shape.
      e, res = DialogService.get_by_id(req["id"])
      if not e:
          return get_error_data_result(retmsg="Fail to new a chat!")
      res = res.to_json()
      res["prompt"] = {key_mapping.get(k, k): v for k, v in res.pop("prompt_config").items()}
      res["prompt"].update({"similarity_threshold": res["similarity_threshold"],
                            "keywords_similarity_weight": res["vector_similarity_weight"],
                            "top_n": res["top_n"],
                            "rerank_model": res["rerank_id"]})
      for key in key_list:
          del res[key]
      res["llm"] = res.pop("llm_setting")
      res["llm"]["model_name"] = res.pop("llm_id")
      del res["kb_ids"]
      res["knowledgebases"] = ids
      res["avatar"] = res.pop("icon")
      return get_result(data=res)
  @manager.route('/chat/<chat_id>', methods=['PUT'])
  @token_required
  def update(tenant_id, chat_id):
      """Partially update the chat ``chat_id`` owned by ``tenant_id``.

      Accepts the same public field names as the create endpoint; only the
      fields present in the JSON body are changed. ``prompt`` and ``llm`` are
      merged into the stored ``prompt_config`` / ``llm_setting`` rather than
      replacing them.

      Returns:
          ``get_result()`` on success, or ``get_error_data_result(...)``
          describing the first failed check.
      """
      if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
          return get_error_data_result(retmsg='You do not own the chat')
      req = request.get_json(silent=True)
      if not isinstance(req, dict):
          return get_error_data_result(retmsg="The request body must be a JSON object.")
      # The remaining body is handed to ``update_by_id`` (presumably writing
      # each key as a column), so it must not be able to reassign the chat to
      # another tenant or change its id; the URL already identifies the chat.
      if req.pop("tenant_id", None):
          return get_error_data_result(retmsg="`tenant_id` must not be provided.")
      req.pop("id", None)

      if "knowledgebases" in req:
          kbs = req.pop("knowledgebases")
          if not kbs:
              return get_error_data_result(retmsg="`knowledgebases` can't be empty value")
          kb_list = []
          for kb in kbs:
              # Accept {"id": ...} objects as well as bare ids (the form the
              # create endpoint takes) instead of failing with a KeyError.
              kb_id = kb.get("id") if isinstance(kb, dict) else kb
              if not kb_id:
                  return get_error_data_result(retmsg="knowledgebase needs id")
              if not KnowledgebaseService.query(id=kb_id, tenant_id=tenant_id):
                  return get_error_data_result(retmsg="you do not own the knowledgebase")
              # Unlike create, datasets without parsed chunks are accepted here.
              kb_list.append(kb_id)
          req["kb_ids"] = kb_list

      # llm: pop unconditionally so a falsy value never reaches update_by_id.
      llm = req.pop("llm", None)
      if llm:
          if "model_name" in llm:
              req["llm_id"] = llm.pop("model_name")
          req["llm_setting"] = llm
      e, _ = TenantService.get_by_id(tenant_id)
      if not e:
          return get_error_data_result(retmsg="Tenant not found!")

      # prompt: maps internal field name -> public API name.
      key_mapping = {"parameters": "variables",
                     "prologue": "opener",
                     "quote": "show_quote",
                     "system": "prompt",
                     "rerank_id": "rerank_model",
                     "vector_similarity_weight": "keywords_similarity_weight"}
      # Retrieval settings stored as top-level fields rather than in prompt_config.
      key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
      prompt = req.pop("prompt", None)
      if prompt:
          for new_key, old_key in key_mapping.items():
              if old_key in prompt:
                  prompt[new_key] = prompt.pop(old_key)
          for key in key_list:
              if key in prompt:
                  req[key] = prompt.pop(key)
          req["prompt_config"] = prompt

      # Ownership was verified above, so the record exists.
      _, res = DialogService.get_by_id(chat_id)
      res = res.to_json()
      if "llm_id" in req and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"]):
          return get_error_data_result(retmsg="The `model_name` does not exist.")
      if "name" in req:
          if not req.get("name"):
              return get_error_data_result(retmsg="`name` can't be empty.")
          # Renaming to a case variant of the current name is always allowed.
          if req["name"].lower() != res["name"].lower() \
                  and DialogService.query(name=req["name"], tenant_id=tenant_id,
                                          status=StatusEnum.VALID.value):
              return get_error_data_result(retmsg="Duplicated chat name in updating chat.")
      if "prompt_config" in req:
          res["prompt_config"].update(req["prompt_config"])
          # Every required variable must appear as ``{key}`` in the system prompt.
          for p in res["prompt_config"]["parameters"]:
              if p.get("optional"):
                  continue
              if "{%s}" % p["key"] not in res["prompt_config"]["system"]:
                  return get_error_data_result(retmsg="Parameter '{}' is not used".format(p["key"]))
      if "llm_setting" in req:
          res["llm_setting"].update(req["llm_setting"])
      # Always write back the merged (stored + requested) settings.
      req["prompt_config"] = res["prompt_config"]
      req["llm_setting"] = res["llm_setting"]
      if "avatar" in req:
          req["icon"] = req.pop("avatar")
      if not DialogService.update_by_id(chat_id, req):
          return get_error_data_result(retmsg="Chat not found!")
      return get_result()
  @manager.route('/chat', methods=['DELETE'])
  @token_required
  def delete(tenant_id):
      """Mark the chats listed in the body's ``ids`` as invalid (soft delete).

      All ids are checked for ownership before any chat is modified, so a
      request that contains a foreign or unknown id changes nothing, and a
      repeated id no longer fails after its chat was already invalidated.

      Returns:
          ``get_result()`` on success, or ``get_error_data_result(...)`` naming
          the first id the tenant does not own.
      """
      req = request.get_json(silent=True)
      if not isinstance(req, dict):
          return get_error_data_result(retmsg="The request body must be a JSON object.")
      ids = req.get("ids")
      if not ids:
          return get_error_data_result(retmsg="`ids` are required")
      # Pass 1: validate every id before touching anything.
      for chat_id in ids:
          if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
              return get_error_data_result(retmsg=f"You don't own the chat {chat_id}")
      # Pass 2: invalidate. A fresh dict per call, since update_by_id receives it.
      for chat_id in ids:
          DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value})
      return get_result()
  @manager.route('/chat', methods=['GET'])
  @token_required
  def list_chat(tenant_id):
      """List the tenant's valid chats in the public API shape.

      Query args: ``id`` and ``name`` (optional filters), ``page`` (default 1),
      ``page_size`` (default 1024), ``orderby`` (default ``create_time``) and
      ``desc`` (descending unless it is ``"False"``/``"false"``).

      Returns:
          ``get_result(data=[...])``; an error if an ``id``/``name`` filter
          matches no valid chat of this tenant, or if a referenced dataset no
          longer exists.
      """
      chat_id = request.args.get("id")
      name = request.args.get("name")
      # Only an explicit filter can "not exist". Scoping the check to the
      # tenant stops callers from probing other tenants' chat ids/names.
      if (chat_id is not None or name is not None) and not DialogService.query(
              id=chat_id, name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value):
          return get_error_data_result(retmsg="The chat doesn't exist")
      page_number = int(request.args.get("page", 1))
      items_per_page = int(request.args.get("page_size", 1024))
      orderby = request.args.get("orderby", "create_time")
      desc = request.args.get("desc") not in ("False", "false")
      chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, chat_id, name)
      if not chats:
          return get_result(data=[])

      # Maps internal field name -> public API name.
      key_mapping = {"parameters": "variables",
                     "prologue": "opener",
                     "quote": "show_quote",
                     "system": "prompt",
                     "rerank_id": "rerank_model",
                     "vector_similarity_weight": "keywords_similarity_weight"}
      # Top-level retrieval fields that are moved into "prompt" in the response.
      key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"]
      list_assts = []
      for res in chats:
          # Build a fresh dict per chat; a single shared dict made every
          # returned chat alias the last chat's prompt.
          prompt = {key_mapping.get(k, k): v for k, v in res.pop("prompt_config").items()}
          prompt.update({"similarity_threshold": res["similarity_threshold"],
                         "keywords_similarity_weight": res["vector_similarity_weight"],
                         "top_n": res["top_n"],
                         "rerank_model": res["rerank_id"]})
          res["prompt"] = prompt
          for key in key_list:
              del res[key]
          res["llm"] = res.pop("llm_setting")
          res["llm"]["model_name"] = res.pop("llm_id")
          kb_list = []
          for kb_id in res.pop("kb_ids"):
              kb = KnowledgebaseService.query(id=kb_id)
              if not kb:
                  return get_error_data_result(retmsg=f"Don't exist the kb {kb_id}")
              kb_list.append(kb[0].to_json())
          res["knowledgebases"] = kb_list
          res["avatar"] = res.pop("icon")
          list_assts.append(res)
      return get_result(data=list_assts)