You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

session.py 9.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. from uuid import uuid4
  18. from flask import request, Response
  19. from api.db import StatusEnum
  20. from api.db.services.dialog_service import DialogService, ConversationService, chat
  21. from api.utils import get_uuid
  22. from api.utils.api_utils import get_error_data_result
  23. from api.utils.api_utils import get_result, token_required
  24. @manager.route('/chat/<chat_id>/session', methods=['POST'])
  25. @token_required
  26. def create(tenant_id, chat_id):
  27. req = request.json
  28. req["dialog_id"] = chat_id
  29. dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
  30. if not dia:
  31. return get_error_data_result(retmsg="You do not own the assistant")
  32. conv = {
  33. "id": get_uuid(),
  34. "dialog_id": req["dialog_id"],
  35. "name": req.get("name", "New session"),
  36. "message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  37. }
  38. if not conv.get("name"):
  39. return get_error_data_result(retmsg="Name can not be empty.")
  40. ConversationService.save(**conv)
  41. e, conv = ConversationService.get_by_id(conv["id"])
  42. if not e:
  43. return get_error_data_result(retmsg="Fail to create a session!")
  44. conv = conv.to_dict()
  45. conv['messages'] = conv.pop("message")
  46. conv["chat_id"] = conv.pop("dialog_id")
  47. del conv["reference"]
  48. return get_result(data=conv)
  49. @manager.route('/chat/<chat_id>/session/<session_id>', methods=['PUT'])
  50. @token_required
  51. def update(tenant_id, chat_id, session_id):
  52. req = request.json
  53. req["dialog_id"] = chat_id
  54. conv_id = session_id
  55. conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
  56. if not conv:
  57. return get_error_data_result(retmsg="Session does not exist")
  58. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  59. return get_error_data_result(retmsg="You do not own the session")
  60. if "message" in req or "messages" in req:
  61. return get_error_data_result(retmsg="Message can not be change")
  62. if "reference" in req:
  63. return get_error_data_result(retmsg="Reference can not be change")
  64. if "name" in req and not req.get("name"):
  65. return get_error_data_result(retmsg="Name can not be empty.")
  66. if not ConversationService.update_by_id(conv_id, req):
  67. return get_error_data_result(retmsg="Session updates error")
  68. return get_result()
  69. @manager.route('/chat/<chat_id>/session/<session_id>/completion', methods=['POST'])
  70. @token_required
  71. def completion(tenant_id, chat_id, session_id):
  72. req = request.json
  73. # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
  74. # {"role": "user", "content": "上海有吗?"}
  75. # ]}
  76. if not req.get("question"):
  77. return get_error_data_result(retmsg="Please input your question.")
  78. conv = ConversationService.query(id=session_id, dialog_id=chat_id)
  79. if not conv:
  80. return get_error_data_result(retmsg="Session does not exist")
  81. conv = conv[0]
  82. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  83. return get_error_data_result(retmsg="You do not own the session")
  84. msg = []
  85. question = {
  86. "content": req.get("question"),
  87. "role": "user",
  88. "id": str(uuid4())
  89. }
  90. conv.message.append(question)
  91. for m in conv.message:
  92. if m["role"] == "system": continue
  93. if m["role"] == "assistant" and not msg: continue
  94. msg.append(m)
  95. message_id = msg[-1].get("id")
  96. e, dia = DialogService.get_by_id(conv.dialog_id)
  97. if not conv.reference:
  98. conv.reference = []
  99. conv.message.append({"role": "assistant", "content": "", "id": message_id})
  100. conv.reference.append({"chunks": [], "doc_aggs": []})
  101. def fillin_conv(ans):
  102. nonlocal conv, message_id
  103. if not conv.reference:
  104. conv.reference.append(ans["reference"])
  105. else:
  106. conv.reference[-1] = ans["reference"]
  107. conv.message[-1] = {"role": "assistant", "content": ans["answer"],
  108. "id": message_id, "prompt": ans.get("prompt", "")}
  109. ans["id"] = message_id
  110. def stream():
  111. nonlocal dia, msg, req, conv
  112. try:
  113. for ans in chat(dia, msg, **req):
  114. fillin_conv(ans)
  115. yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
  116. ConversationService.update_by_id(conv.id, conv.to_dict())
  117. except Exception as e:
  118. yield "data:" + json.dumps({"code": 500, "message": str(e),
  119. "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
  120. ensure_ascii=False) + "\n\n"
  121. yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
  122. if req.get("stream", True):
  123. resp = Response(stream(), mimetype="text/event-stream")
  124. resp.headers.add_header("Cache-control", "no-cache")
  125. resp.headers.add_header("Connection", "keep-alive")
  126. resp.headers.add_header("X-Accel-Buffering", "no")
  127. resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
  128. return resp
  129. else:
  130. answer = None
  131. for ans in chat(dia, msg, **req):
  132. answer = ans
  133. fillin_conv(ans)
  134. ConversationService.update_by_id(conv.id, conv.to_dict())
  135. break
  136. return get_result(data=answer)
  137. @manager.route('/chat/<chat_id>/session', methods=['GET'])
  138. @token_required
  139. def list(chat_id, tenant_id):
  140. if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
  141. return get_error_data_result(retmsg=f"You don't own the assistant {chat_id}.")
  142. id = request.args.get("id")
  143. name = request.args.get("name")
  144. session = ConversationService.query(id=id, name=name, dialog_id=chat_id)
  145. if not session:
  146. return get_error_data_result(retmsg="The session doesn't exist")
  147. page_number = int(request.args.get("page", 1))
  148. items_per_page = int(request.args.get("page_size", 1024))
  149. orderby = request.args.get("orderby", "create_time")
  150. if request.args.get("desc") == "False" or request.args.get("desc") == "false":
  151. desc = False
  152. else:
  153. desc = True
  154. convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name)
  155. if not convs:
  156. return get_result(data=[])
  157. for conv in convs:
  158. conv['messages'] = conv.pop("message")
  159. conv["chat"] = conv.pop("dialog_id")
  160. if conv["reference"]:
  161. messages = conv["messages"]
  162. message_num = 0
  163. chunk_num = 0
  164. while message_num < len(messages):
  165. if message_num != 0 and messages[message_num]["role"] != "user":
  166. chunk_list = []
  167. if "chunks" in conv["reference"][chunk_num]:
  168. chunks = conv["reference"][chunk_num]["chunks"]
  169. for chunk in chunks:
  170. new_chunk = {
  171. "id": chunk["chunk_id"],
  172. "content": chunk["content_with_weight"],
  173. "document_id": chunk["doc_id"],
  174. "document_name": chunk["docnm_kwd"],
  175. "knowledgebase_id": chunk["kb_id"],
  176. "image_id": chunk["img_id"],
  177. "similarity": chunk["similarity"],
  178. "vector_similarity": chunk["vector_similarity"],
  179. "term_similarity": chunk["term_similarity"],
  180. "positions": chunk["positions"],
  181. }
  182. chunk_list.append(new_chunk)
  183. chunk_num += 1
  184. messages[message_num]["reference"] = chunk_list
  185. message_num += 1
  186. del conv["reference"]
  187. return get_result(data=convs)
  188. @manager.route('/chat/<chat_id>/session', methods=["DELETE"])
  189. @token_required
  190. def delete(tenant_id, chat_id):
  191. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  192. return get_error_data_result(retmsg="You don't own the chat")
  193. ids = request.json.get("ids")
  194. if not ids:
  195. return get_error_data_result(retmsg="`ids` is required in deleting operation")
  196. for id in ids:
  197. conv = ConversationService.query(id=id, dialog_id=chat_id)
  198. if not conv:
  199. return get_error_data_result(retmsg="The chat doesn't own the session")
  200. ConversationService.delete_by_id(id)
  201. return get_result()