You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

session.py 2.8KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import json
  2. from .base import Base
  3. class Session(Base):
  4. def __init__(self, rag, res_dict):
  5. self.id = None
  6. self.name = "New session"
  7. self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  8. for key,value in res_dict.items():
  9. if key =="chat_id" and value is not None:
  10. self.chat_id = None
  11. self.__session_type = "chat"
  12. if key == "agent_id" and value is not None:
  13. self.agent_id = None
  14. self.__session_type = "agent"
  15. super().__init__(rag, res_dict)
  16. def ask(self, question="",stream=True,**kwargs):
  17. if self.__session_type == "agent":
  18. res=self._ask_agent(question,stream)
  19. elif self.__session_type == "chat":
  20. res=self._ask_chat(question,stream,**kwargs)
  21. for line in res.iter_lines():
  22. line = line.decode("utf-8")
  23. if line.startswith("{"):
  24. json_data = json.loads(line)
  25. raise Exception(json_data["message"])
  26. if not line.startswith("data:"):
  27. continue
  28. json_data = json.loads(line[5:])
  29. if json_data["data"] is True or json_data["data"].get("running_status"):
  30. continue
  31. answer = json_data["data"]["answer"]
  32. reference = json_data["data"].get("reference", {})
  33. temp_dict = {
  34. "content": answer,
  35. "role": "assistant"
  36. }
  37. if reference and "chunks" in reference:
  38. chunks = reference["chunks"]
  39. temp_dict["reference"] = chunks
  40. message = Message(self.rag, temp_dict)
  41. yield message
  42. def _ask_chat(self, question: str, stream: bool,**kwargs):
  43. json_data={"question": question, "stream": True,"session_id":self.id}
  44. json_data.update(kwargs)
  45. res = self.post(f"/chats/{self.chat_id}/completions",
  46. json_data, stream=stream)
  47. return res
  48. def _ask_agent(self,question:str,stream:bool):
  49. res = self.post(f"/agents/{self.agent_id}/completions",
  50. {"question": question, "stream": True,"session_id":self.id}, stream=stream)
  51. return res
  52. def update(self,update_message):
  53. res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}",
  54. update_message)
  55. res = res.json()
  56. if res.get("code") != 0:
  57. raise Exception(res.get("message"))
  58. class Message(Base):
  59. def __init__(self, rag, res_dict):
  60. self.content = "Hi! I am your assistant,can I help you?"
  61. self.reference = None
  62. self.role = "assistant"
  63. self.prompt = None
  64. self.id = None
  65. super().__init__(rag, res_dict)