Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

session.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. from uuid import uuid4
  18. from flask import request, Response
  19. from api.db import StatusEnum
  20. from api.db.services.dialog_service import DialogService, ConversationService, chat
  21. from api.settings import RetCode
  22. from api.utils import get_uuid
  23. from api.utils.api_utils import get_data_error_result
  24. from api.utils.api_utils import get_json_result, token_required
  25. @manager.route('/save', methods=['POST'])
  26. @token_required
  27. def set_conversation(tenant_id):
  28. req = request.json
  29. conv_id = req.get("id")
  30. if "assistant_id" in req:
  31. req["dialog_id"] = req.pop("assistant_id")
  32. if "id" in req:
  33. del req["id"]
  34. conv = ConversationService.query(id=conv_id)
  35. if not conv:
  36. return get_data_error_result(retmsg="Session does not exist")
  37. if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  38. return get_data_error_result(retmsg="You do not own the session")
  39. if req.get("dialog_id"):
  40. dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
  41. if not dia:
  42. return get_data_error_result(retmsg="You do not own the assistant")
  43. if "dialog_id" in req and not req.get("dialog_id"):
  44. return get_data_error_result(retmsg="assistant_id can not be empty.")
  45. if "message" in req:
  46. return get_data_error_result(retmsg="message can not be change")
  47. if "reference" in req:
  48. return get_data_error_result(retmsg="reference can not be change")
  49. if "name" in req and not req.get("name"):
  50. return get_data_error_result(retmsg="name can not be empty.")
  51. if not ConversationService.update_by_id(conv_id, req):
  52. return get_data_error_result(retmsg="Session updates error")
  53. return get_json_result(data=True)
  54. if not req.get("dialog_id"):
  55. return get_data_error_result(retmsg="assistant_id is required.")
  56. dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
  57. if not dia:
  58. return get_data_error_result(retmsg="You do not own the assistant")
  59. conv = {
  60. "id": get_uuid(),
  61. "dialog_id": req["dialog_id"],
  62. "name": req.get("name", "New session"),
  63. "message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  64. }
  65. if not conv.get("name"):
  66. return get_data_error_result(retmsg="name can not be empty.")
  67. ConversationService.save(**conv)
  68. e, conv = ConversationService.get_by_id(conv["id"])
  69. if not e:
  70. return get_data_error_result(retmsg="Fail to new session!")
  71. conv = conv.to_dict()
  72. conv['messages'] = conv.pop("message")
  73. conv["assistant_id"] = conv.pop("dialog_id")
  74. del conv["reference"]
  75. return get_json_result(data=conv)
  76. @manager.route('/completion', methods=['POST'])
  77. @token_required
  78. def completion(tenant_id):
  79. req = request.json
  80. # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
  81. # {"role": "user", "content": "上海有吗?"}
  82. # ]}
  83. if "session_id" not in req:
  84. return get_data_error_result(retmsg="session_id is required")
  85. conv = ConversationService.query(id=req["session_id"])
  86. if not conv:
  87. return get_data_error_result(retmsg="Session does not exist")
  88. conv = conv[0]
  89. if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  90. return get_data_error_result(retmsg="You do not own the session")
  91. msg = []
  92. question = {
  93. "content": req.get("question"),
  94. "role": "user",
  95. "id": str(uuid4())
  96. }
  97. conv.message.append(question)
  98. for m in conv.message:
  99. if m["role"] == "system": continue
  100. if m["role"] == "assistant" and not msg: continue
  101. msg.append(m)
  102. message_id = msg[-1].get("id")
  103. e, dia = DialogService.get_by_id(conv.dialog_id)
  104. del req["session_id"]
  105. if not conv.reference:
  106. conv.reference = []
  107. conv.message.append({"role": "assistant", "content": "", "id": message_id})
  108. conv.reference.append({"chunks": [], "doc_aggs": []})
  109. def fillin_conv(ans):
  110. nonlocal conv, message_id
  111. if not conv.reference:
  112. conv.reference.append(ans["reference"])
  113. else:
  114. conv.reference[-1] = ans["reference"]
  115. conv.message[-1] = {"role": "assistant", "content": ans["answer"],
  116. "id": message_id, "prompt": ans.get("prompt", "")}
  117. ans["id"] = message_id
  118. def stream():
  119. nonlocal dia, msg, req, conv
  120. try:
  121. for ans in chat(dia, msg, **req):
  122. fillin_conv(ans)
  123. yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
  124. ConversationService.update_by_id(conv.id, conv.to_dict())
  125. except Exception as e:
  126. yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
  127. "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
  128. ensure_ascii=False) + "\n\n"
  129. yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
  130. if req.get("stream", True):
  131. resp = Response(stream(), mimetype="text/event-stream")
  132. resp.headers.add_header("Cache-control", "no-cache")
  133. resp.headers.add_header("Connection", "keep-alive")
  134. resp.headers.add_header("X-Accel-Buffering", "no")
  135. resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
  136. return resp
  137. else:
  138. answer = None
  139. for ans in chat(dia, msg, **req):
  140. answer = ans
  141. fillin_conv(ans)
  142. ConversationService.update_by_id(conv.id, conv.to_dict())
  143. break
  144. return get_json_result(data=answer)
  145. @manager.route('/get', methods=['GET'])
  146. @token_required
  147. def get(tenant_id):
  148. req = request.args
  149. if "id" not in req:
  150. return get_data_error_result(retmsg="id is required")
  151. conv_id = req["id"]
  152. conv = ConversationService.query(id=conv_id)
  153. if not conv:
  154. return get_data_error_result(retmsg="Session does not exist")
  155. if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  156. return get_data_error_result(retmsg="You do not own the session")
  157. if "assistant_id" in req:
  158. if req["assistant_id"] != conv[0].dialog_id:
  159. return get_data_error_result(retmsg="The session doesn't belong to the assistant")
  160. conv = conv[0].to_dict()
  161. conv['messages'] = conv.pop("message")
  162. conv["assistant_id"] = conv.pop("dialog_id")
  163. if conv["reference"]:
  164. messages = conv["messages"]
  165. message_num = 0
  166. chunk_num = 0
  167. while message_num < len(messages):
  168. if message_num != 0 and messages[message_num]["role"] != "user":
  169. chunk_list = []
  170. if "chunks" in conv["reference"][chunk_num]:
  171. chunks = conv["reference"][chunk_num]["chunks"]
  172. for chunk in chunks:
  173. new_chunk = {
  174. "id": chunk["chunk_id"],
  175. "content": chunk["content_with_weight"],
  176. "document_id": chunk["doc_id"],
  177. "document_name": chunk["docnm_kwd"],
  178. "knowledgebase_id": chunk["kb_id"],
  179. "image_id": chunk["img_id"],
  180. "similarity": chunk["similarity"],
  181. "vector_similarity": chunk["vector_similarity"],
  182. "term_similarity": chunk["term_similarity"],
  183. "positions": chunk["positions"],
  184. }
  185. chunk_list.append(new_chunk)
  186. chunk_num += 1
  187. messages[message_num]["reference"] = chunk_list
  188. message_num += 1
  189. del conv["reference"]
  190. return get_json_result(data=conv)
  191. @manager.route('/list', methods=["GET"])
  192. @token_required
  193. def list(tenant_id):
  194. assistant_id = request.args["assistant_id"]
  195. if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
  196. return get_json_result(
  197. data=False, retmsg=f"You don't own the assistant.",
  198. retcode=RetCode.OPERATING_ERROR)
  199. convs = ConversationService.query(
  200. dialog_id=assistant_id,
  201. order_by=ConversationService.model.create_time,
  202. reverse=True)
  203. convs = [d.to_dict() for d in convs]
  204. for conv in convs:
  205. conv['messages'] = conv.pop("message")
  206. conv["assistant_id"] = conv.pop("dialog_id")
  207. if conv["reference"]:
  208. messages = conv["messages"]
  209. message_num = 0
  210. chunk_num = 0
  211. while message_num < len(messages):
  212. if message_num != 0 and messages[message_num]["role"] != "user":
  213. chunk_list = []
  214. if "chunks" in conv["reference"][chunk_num]:
  215. chunks = conv["reference"][chunk_num]["chunks"]
  216. for chunk in chunks:
  217. new_chunk = {
  218. "id": chunk["chunk_id"],
  219. "content": chunk["content_with_weight"],
  220. "document_id": chunk["doc_id"],
  221. "document_name": chunk["docnm_kwd"],
  222. "knowledgebase_id": chunk["kb_id"],
  223. "image_id": chunk["img_id"],
  224. "similarity": chunk["similarity"],
  225. "vector_similarity": chunk["vector_similarity"],
  226. "term_similarity": chunk["term_similarity"],
  227. "positions": chunk["positions"],
  228. }
  229. chunk_list.append(new_chunk)
  230. chunk_num += 1
  231. messages[message_num]["reference"] = chunk_list
  232. message_num += 1
  233. del conv["reference"]
  234. return get_json_result(data=convs)
  235. @manager.route('/delete', methods=["DELETE"])
  236. @token_required
  237. def delete(tenant_id):
  238. id = request.args.get("id")
  239. if not id:
  240. return get_data_error_result(retmsg="`id` is required in deleting operation")
  241. conv = ConversationService.query(id=id)
  242. if not conv:
  243. return get_data_error_result(retmsg="Session doesn't exist")
  244. conv = conv[0]
  245. if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  246. return get_data_error_result(retmsg="You don't own the session")
  247. ConversationService.delete_by_id(id)
  248. return get_json_result(data=True)