You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

session.py 9.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. from uuid import uuid4
  18. from flask import request, Response
  19. from api.db import StatusEnum
  20. from api.db.services.dialog_service import DialogService, ConversationService, chat
  21. from api.utils import get_uuid
  22. from api.utils.api_utils import get_error_data_result
  23. from api.utils.api_utils import get_result, token_required
  24. @manager.route('/chats/<chat_id>/sessions', methods=['POST'])
  25. @token_required
  26. def create(tenant_id,chat_id):
  27. req = request.json
  28. req["dialog_id"] = chat_id
  29. dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
  30. if not dia:
  31. return get_error_data_result(retmsg="You do not own the assistant")
  32. conv = {
  33. "id": get_uuid(),
  34. "dialog_id": req["dialog_id"],
  35. "name": req.get("name", "New session"),
  36. "message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  37. }
  38. if not conv.get("name"):
  39. return get_error_data_result(retmsg="`name` can not be empty.")
  40. ConversationService.save(**conv)
  41. e, conv = ConversationService.get_by_id(conv["id"])
  42. if not e:
  43. return get_error_data_result(retmsg="Fail to create a session!")
  44. conv = conv.to_dict()
  45. conv['messages'] = conv.pop("message")
  46. conv["chat_id"] = conv.pop("dialog_id")
  47. del conv["reference"]
  48. return get_result(data=conv)
  49. @manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
  50. @token_required
  51. def update(tenant_id,chat_id,session_id):
  52. req = request.json
  53. req["dialog_id"] = chat_id
  54. conv_id = session_id
  55. conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
  56. if not conv:
  57. return get_error_data_result(retmsg="Session does not exist")
  58. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  59. return get_error_data_result(retmsg="You do not own the session")
  60. if "message" in req or "messages" in req:
  61. return get_error_data_result(retmsg="`message` can not be change")
  62. if "reference" in req:
  63. return get_error_data_result(retmsg="`reference` can not be change")
  64. if "name" in req and not req.get("name"):
  65. return get_error_data_result(retmsg="`name` can not be empty.")
  66. if not ConversationService.update_by_id(conv_id, req):
  67. return get_error_data_result(retmsg="Session updates error")
  68. return get_result()
  69. @manager.route('/chats/<chat_id>/completions', methods=['POST'])
  70. @token_required
  71. def completion(tenant_id,chat_id):
  72. req = request.json
  73. if not req.get("session_id"):
  74. conv = {
  75. "id": get_uuid(),
  76. "dialog_id": chat_id,
  77. "name": req.get("name", "New session"),
  78. "message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  79. }
  80. if not conv.get("name"):
  81. return get_error_data_result(retmsg="`name` can not be empty.")
  82. ConversationService.save(**conv)
  83. e, conv = ConversationService.get_by_id(conv["id"])
  84. session_id=conv.id
  85. else:
  86. session_id = req.get("session_id")
  87. if not req.get("question"):
  88. return get_error_data_result(retmsg="Please input your question.")
  89. conv = ConversationService.query(id=session_id,dialog_id=chat_id)
  90. if not conv:
  91. return get_error_data_result(retmsg="Session does not exist")
  92. conv = conv[0]
  93. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  94. return get_error_data_result(retmsg="You do not own the chat")
  95. msg = []
  96. question = {
  97. "content": req.get("question"),
  98. "role": "user",
  99. "id": str(uuid4())
  100. }
  101. conv.message.append(question)
  102. for m in conv.message:
  103. if m["role"] == "system": continue
  104. if m["role"] == "assistant" and not msg: continue
  105. msg.append(m)
  106. message_id = msg[-1].get("id")
  107. e, dia = DialogService.get_by_id(conv.dialog_id)
  108. if not conv.reference:
  109. conv.reference = []
  110. conv.message.append({"role": "assistant", "content": "", "id": message_id})
  111. conv.reference.append({"chunks": [], "doc_aggs": []})
  112. def fillin_conv(ans):
  113. nonlocal conv, message_id
  114. if not conv.reference:
  115. conv.reference.append(ans["reference"])
  116. else:
  117. conv.reference[-1] = ans["reference"]
  118. conv.message[-1] = {"role": "assistant", "content": ans["answer"],
  119. "id": message_id, "prompt": ans.get("prompt", "")}
  120. ans["id"] = message_id
  121. ans["session_id"]=session_id
  122. def stream():
  123. nonlocal dia, msg, req, conv
  124. try:
  125. for ans in chat(dia, msg, **req):
  126. fillin_conv(ans)
  127. yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
  128. ConversationService.update_by_id(conv.id, conv.to_dict())
  129. except Exception as e:
  130. yield "data:" + json.dumps({"code": 500, "message": str(e),
  131. "data": {"answer": "**ERROR**: " + str(e),"reference": []}},
  132. ensure_ascii=False) + "\n\n"
  133. yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
  134. if req.get("stream", True):
  135. resp = Response(stream(), mimetype="text/event-stream")
  136. resp.headers.add_header("Cache-control", "no-cache")
  137. resp.headers.add_header("Connection", "keep-alive")
  138. resp.headers.add_header("X-Accel-Buffering", "no")
  139. resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
  140. return resp
  141. else:
  142. answer = None
  143. for ans in chat(dia, msg, **req):
  144. answer = ans
  145. fillin_conv(ans)
  146. ConversationService.update_by_id(conv.id, conv.to_dict())
  147. break
  148. return get_result(data=answer)
  149. @manager.route('/chats/<chat_id>/sessions', methods=['GET'])
  150. @token_required
  151. def list(chat_id,tenant_id):
  152. if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
  153. return get_error_data_result(retmsg=f"You don't own the assistant {chat_id}.")
  154. id = request.args.get("id")
  155. name = request.args.get("name")
  156. page_number = int(request.args.get("page", 1))
  157. items_per_page = int(request.args.get("page_size", 1024))
  158. orderby = request.args.get("orderby", "create_time")
  159. if request.args.get("desc") == "False" or request.args.get("desc") == "false":
  160. desc = False
  161. else:
  162. desc = True
  163. convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
  164. if not convs:
  165. return get_result(data=[])
  166. for conv in convs:
  167. conv['messages'] = conv.pop("message")
  168. infos = conv["messages"]
  169. for info in infos:
  170. if "prompt" in info:
  171. info.pop("prompt")
  172. conv["chat"] = conv.pop("dialog_id")
  173. if conv["reference"]:
  174. messages = conv["messages"]
  175. message_num = 0
  176. chunk_num = 0
  177. while message_num < len(messages):
  178. if message_num != 0 and messages[message_num]["role"] != "user":
  179. chunk_list = []
  180. if "chunks" in conv["reference"][chunk_num]:
  181. chunks = conv["reference"][chunk_num]["chunks"]
  182. for chunk in chunks:
  183. new_chunk = {
  184. "id": chunk["chunk_id"],
  185. "content": chunk["content_with_weight"],
  186. "document_id": chunk["doc_id"],
  187. "document_name": chunk["docnm_kwd"],
  188. "dataset_id": chunk["kb_id"],
  189. "image_id": chunk["img_id"],
  190. "similarity": chunk["similarity"],
  191. "vector_similarity": chunk["vector_similarity"],
  192. "term_similarity": chunk["term_similarity"],
  193. "positions": chunk["positions"],
  194. }
  195. chunk_list.append(new_chunk)
  196. chunk_num += 1
  197. messages[message_num]["reference"] = chunk_list
  198. message_num += 1
  199. del conv["reference"]
  200. return get_result(data=convs)
  201. @manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
  202. @token_required
  203. def delete(tenant_id,chat_id):
  204. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  205. return get_error_data_result(retmsg="You don't own the chat")
  206. req = request.json
  207. convs = ConversationService.query(dialog_id=chat_id)
  208. if not req:
  209. ids = None
  210. else:
  211. ids=req.get("ids")
  212. if not ids:
  213. conv_list = []
  214. for conv in convs:
  215. conv_list.append(conv.id)
  216. else:
  217. conv_list=ids
  218. for id in conv_list:
  219. conv = ConversationService.query(id=id,dialog_id=chat_id)
  220. if not conv:
  221. return get_error_data_result(retmsg="The chat doesn't own the session")
  222. ConversationService.delete_by_id(id)
  223. return get_result()