You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

session.py 2.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import json
  2. from .base import Base
  3. class Session(Base):
  4. def __init__(self, rag, res_dict):
  5. self.id = None
  6. self.name = "New session"
  7. self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
  8. for key,value in res_dict.items():
  9. if key =="chat_id" and value is not None:
  10. self.chat_id = None
  11. self.__session_type = "chat"
  12. if key == "agent_id" and value is not None:
  13. self.agent_id = None
  14. self.__session_type = "agent"
  15. super().__init__(rag, res_dict)
  16. def ask(self, question,stream=True):
  17. if self.__session_type == "agent":
  18. res=self._ask_agent(question,stream)
  19. elif self.__session_type == "chat":
  20. res=self._ask_chat(question,stream)
  21. for line in res.iter_lines():
  22. line = line.decode("utf-8")
  23. if line.startswith("{"):
  24. json_data = json.loads(line)
  25. raise Exception(json_data["message"])
  26. if line.startswith("data:"):
  27. json_data = json.loads(line[5:])
  28. if not json_data["data"]:
  29. answer = json_data["data"]["answer"]
  30. reference = json_data["data"]["reference"]
  31. temp_dict = {
  32. "content": answer,
  33. "role": "assistant"
  34. }
  35. if "chunks" in reference:
  36. chunks = reference["chunks"]
  37. temp_dict["reference"] = chunks
  38. message = Message(self.rag, temp_dict)
  39. yield message
  40. def _ask_chat(self, question: str, stream: bool):
  41. res = self.post(f"/chats/{self.chat_id}/completions",
  42. {"question": question, "stream": True,"session_id":self.id}, stream=stream)
  43. return res
  44. def _ask_agent(self,question:str,stream:bool):
  45. res = self.post(f"/agents/{self.agent_id}/completions",
  46. {"question": question, "stream": True,"session_id":self.id}, stream=stream)
  47. return res
  48. def update(self,update_message):
  49. res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}",
  50. update_message)
  51. res = res.json()
  52. if res.get("code") != 0:
  53. raise Exception(res.get("message"))
  54. class Message(Base):
  55. def __init__(self, rag, res_dict):
  56. self.content = "Hi! I am your assistant,can I help you?"
  57. self.reference = None
  58. self.role = "assistant"
  59. self.prompt = None
  60. self.id = None
  61. super().__init__(rag, res_dict)