您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

session.py 9.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. from uuid import uuid4
  18. from flask import request, Response
  19. from api.db import StatusEnum
  20. from api.db.services.dialog_service import DialogService, ConversationService, chat
  21. from api.utils import get_uuid
  22. from api.utils.api_utils import get_error_data_result
  23. from api.utils.api_utils import get_result, token_required
  24. @manager.route('/chat/<chat_id>/session', methods=['POST'])
  25. @token_required
  26. def create(tenant_id,chat_id):
  27. req = request.json
  28. req["dialog_id"] = chat_id
  29. dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
  30. if not dia:
  31. return get_error_data_result(retmsg="You do not own the assistant")
  32. conv = {
  33. "id": get_uuid(),
  34. "dialog_id": req["dialog_id"],
  35. "name": req.get("name", "New session"),
  36. "message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  37. }
  38. if not conv.get("name"):
  39. return get_error_data_result(retmsg="`name` can not be empty.")
  40. ConversationService.save(**conv)
  41. e, conv = ConversationService.get_by_id(conv["id"])
  42. if not e:
  43. return get_error_data_result(retmsg="Fail to create a session!")
  44. conv = conv.to_dict()
  45. conv['messages'] = conv.pop("message")
  46. conv["chat_id"] = conv.pop("dialog_id")
  47. del conv["reference"]
  48. return get_result(data=conv)
  49. @manager.route('/chat/<chat_id>/session/<session_id>', methods=['PUT'])
  50. @token_required
  51. def update(tenant_id,chat_id,session_id):
  52. req = request.json
  53. req["dialog_id"] = chat_id
  54. conv_id = session_id
  55. conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
  56. if not conv:
  57. return get_error_data_result(retmsg="Session does not exist")
  58. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  59. return get_error_data_result(retmsg="You do not own the session")
  60. if "message" in req or "messages" in req:
  61. return get_error_data_result(retmsg="`message` can not be change")
  62. if "reference" in req:
  63. return get_error_data_result(retmsg="`reference` can not be change")
  64. if "name" in req and not req.get("name"):
  65. return get_error_data_result(retmsg="`name` can not be empty.")
  66. if not ConversationService.update_by_id(conv_id, req):
  67. return get_error_data_result(retmsg="Session updates error")
  68. return get_result()
  69. @manager.route('/chat/<chat_id>/completion', methods=['POST'])
  70. @token_required
  71. def completion(tenant_id,chat_id):
  72. req = request.json
  73. # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
  74. # {"role": "user", "content": "上海有吗?"}
  75. # ]}
  76. if not req.get("session_id"):
  77. conv = {
  78. "id": get_uuid(),
  79. "dialog_id": chat_id,
  80. "name": req.get("name", "New session"),
  81. "message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  82. }
  83. if not conv.get("name"):
  84. return get_error_data_result(retmsg="`name` can not be empty.")
  85. ConversationService.save(**conv)
  86. e, conv = ConversationService.get_by_id(conv["id"])
  87. session_id=conv.id
  88. else:
  89. session_id = req.get("session_id")
  90. if not req.get("question"):
  91. return get_error_data_result(retmsg="Please input your question.")
  92. conv = ConversationService.query(id=session_id,dialog_id=chat_id)
  93. if not conv:
  94. return get_error_data_result(retmsg="Session does not exist")
  95. conv = conv[0]
  96. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  97. return get_error_data_result(retmsg="You do not own the session")
  98. msg = []
  99. question = {
  100. "content": req.get("question"),
  101. "role": "user",
  102. "id": str(uuid4())
  103. }
  104. conv.message.append(question)
  105. for m in conv.message:
  106. if m["role"] == "system": continue
  107. if m["role"] == "assistant" and not msg: continue
  108. msg.append(m)
  109. message_id = msg[-1].get("id")
  110. e, dia = DialogService.get_by_id(conv.dialog_id)
  111. if not conv.reference:
  112. conv.reference = []
  113. conv.message.append({"role": "assistant", "content": "", "id": message_id})
  114. conv.reference.append({"chunks": [], "doc_aggs": []})
  115. def fillin_conv(ans):
  116. nonlocal conv, message_id
  117. if not conv.reference:
  118. conv.reference.append(ans["reference"])
  119. else:
  120. conv.reference[-1] = ans["reference"]
  121. conv.message[-1] = {"role": "assistant", "content": ans["answer"],
  122. "id": message_id, "prompt": ans.get("prompt", "")}
  123. ans["id"] = message_id
  124. ans["session_id"]=session_id
  125. def stream():
  126. nonlocal dia, msg, req, conv
  127. try:
  128. for ans in chat(dia, msg, **req):
  129. fillin_conv(ans)
  130. yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
  131. ConversationService.update_by_id(conv.id, conv.to_dict())
  132. except Exception as e:
  133. yield "data:" + json.dumps({"code": 500, "message": str(e),
  134. "data": {"answer": "**ERROR**: " + str(e),"reference": []}},
  135. ensure_ascii=False) + "\n\n"
  136. yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
  137. if req.get("stream", True):
  138. resp = Response(stream(), mimetype="text/event-stream")
  139. resp.headers.add_header("Cache-control", "no-cache")
  140. resp.headers.add_header("Connection", "keep-alive")
  141. resp.headers.add_header("X-Accel-Buffering", "no")
  142. resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
  143. return resp
  144. else:
  145. answer = None
  146. for ans in chat(dia, msg, **req):
  147. answer = ans
  148. fillin_conv(ans)
  149. ConversationService.update_by_id(conv.id, conv.to_dict())
  150. break
  151. return get_result(data=answer)
  152. @manager.route('/chat/<chat_id>/session', methods=['GET'])
  153. @token_required
  154. def list(chat_id,tenant_id):
  155. if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
  156. return get_error_data_result(retmsg=f"You don't own the assistant {chat_id}.")
  157. id = request.args.get("id")
  158. name = request.args.get("name")
  159. session = ConversationService.query(id=id,name=name,dialog_id=chat_id)
  160. if not session:
  161. return get_error_data_result(retmsg="The session doesn't exist")
  162. page_number = int(request.args.get("page", 1))
  163. items_per_page = int(request.args.get("page_size", 1024))
  164. orderby = request.args.get("orderby", "create_time")
  165. if request.args.get("desc") == "False" or request.args.get("desc") == "false":
  166. desc = False
  167. else:
  168. desc = True
  169. convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
  170. if not convs:
  171. return get_result(data=[])
  172. for conv in convs:
  173. conv['messages'] = conv.pop("message")
  174. conv["chat"] = conv.pop("dialog_id")
  175. if conv["reference"]:
  176. messages = conv["messages"]
  177. message_num = 0
  178. chunk_num = 0
  179. while message_num < len(messages):
  180. if message_num != 0 and messages[message_num]["role"] != "user":
  181. chunk_list = []
  182. if "chunks" in conv["reference"][chunk_num]:
  183. chunks = conv["reference"][chunk_num]["chunks"]
  184. for chunk in chunks:
  185. new_chunk = {
  186. "id": chunk["chunk_id"],
  187. "content": chunk["content_with_weight"],
  188. "document_id": chunk["doc_id"],
  189. "document_name": chunk["docnm_kwd"],
  190. "knowledgebase_id": chunk["kb_id"],
  191. "image_id": chunk["img_id"],
  192. "similarity": chunk["similarity"],
  193. "vector_similarity": chunk["vector_similarity"],
  194. "term_similarity": chunk["term_similarity"],
  195. "positions": chunk["positions"],
  196. }
  197. chunk_list.append(new_chunk)
  198. chunk_num += 1
  199. messages[message_num]["reference"] = chunk_list
  200. message_num += 1
  201. del conv["reference"]
  202. return get_result(data=convs)
  203. @manager.route('/chat/<chat_id>/session', methods=["DELETE"])
  204. @token_required
  205. def delete(tenant_id,chat_id):
  206. if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
  207. return get_error_data_result(retmsg="You don't own the chat")
  208. ids = request.json.get("ids")
  209. if not ids:
  210. return get_error_data_result(retmsg="`ids` is required in deleting operation")
  211. for id in ids:
  212. conv = ConversationService.query(id=id,dialog_id=chat_id)
  213. if not conv:
  214. return get_error_data_result(retmsg="The chat doesn't own the session")
  215. ConversationService.delete_by_id(id)
  216. return get_result()