You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ragflow.py 8.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import List
  16. import requests
  17. from .modules.chat import Chat
  18. from .modules.chunk import Chunk
  19. from .modules.dataset import DataSet
  20. from .modules.document import Document
  21. class RAGFlow:
  22. def __init__(self, api_key, base_url, version='v1'):
  23. """
  24. api_url: http://<host_address>/api/v1
  25. """
  26. self.user_key = api_key
  27. self.api_url = f"{base_url}/api/{version}"
  28. self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
  def post(self, path, json=None, stream=False, files=None):
      """
      Issue a POST to ``api_url + path`` with the authorization header.

      ``json``, ``stream`` and ``files`` are handed to ``requests.post`` unchanged.
      Returns the raw response object; no status or body checks are done here.
      """
      target = self.api_url + path
      return requests.post(
          url=target,
          json=json,
          headers=self.authorization_header,
          stream=stream,
          files=files,
      )
  def get(self, path, params=None, json=None):
      """
      Issue a GET to ``api_url + path`` with the authorization header.

      ``params`` and ``json`` are handed to ``requests.get`` unchanged.
      Returns the raw response object; no status or body checks are done here.
      """
      target = self.api_url + path
      return requests.get(
          url=target,
          params=params,
          headers=self.authorization_header,
          json=json,
      )
  def delete(self, path, json):
      """
      Issue a DELETE to ``api_url + path`` with ``json`` as the request body
      and the authorization header attached.

      Returns the raw response object; no status or body checks are done here.
      """
      target = self.api_url + path
      return requests.delete(url=target, json=json, headers=self.authorization_header)
  def put(self, path, json):
      """
      Issue a PUT to ``api_url + path`` with ``json`` as the request body
      and the authorization header attached.

      Returns the raw response object; no status or body checks are done here.
      """
      target = self.api_url + path
      return requests.put(url=target, json=json, headers=self.authorization_header)
  def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
                     permission: str = "me", chunk_method: str = "naive",
                     parser_config: DataSet.ParserConfig = None) -> DataSet:
      """
      Create a dataset via ``POST /datasets``.

      All arguments are sent in the JSON body under keys of the same name.
      A truthy ``parser_config`` is serialized with its ``to_json()`` first;
      a falsy one is sent as given.

      :return: ``DataSet`` built from the ``data`` field of the response.
      :raises Exception: with the response's ``message`` when ``code`` is not 0.
      """
      payload = {
          "name": name,
          "avatar": avatar,
          "description": description,
          "language": language,
          "permission": permission,
          "chunk_method": chunk_method,
          "parser_config": parser_config.to_json() if parser_config else parser_config,
      }
      body = self.post("/datasets", payload).json()
      if body.get("code") != 0:
          raise Exception(body["message"])
      return DataSet(self, body["data"])
  def delete_datasets(self, ids: List[str]):
      """
      Delete datasets via ``DELETE /datasets`` with ``{"ids": ids}`` as the JSON body.

      :raises Exception: with the response's ``message`` when ``code`` is not 0.
      """
      payload = {"ids": ids}
      body = self.delete("/datasets", payload).json()
      if body.get("code") == 0:
          return
      raise Exception(body["message"])
  61. def get_dataset(self,name: str):
  62. _list = self.list_datasets(name=name)
  63. if len(_list) > 0:
  64. return _list[0]
  65. raise Exception("Dataset %s not found" % name)
  def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True,
                    id: str = None, name: str = None) -> List[DataSet]:
      """
      List datasets via ``GET /datasets``.

      Every argument is sent as a query parameter under the same name.

      :return: one ``DataSet`` per entry in the response's ``data`` list.
      :raises Exception: with the response's ``message`` when ``code`` is not 0.
      """
      query = {
          "page": page,
          "page_size": page_size,
          "orderby": orderby,
          "desc": desc,
          "id": id,
          "name": name,
      }
      body = self.get("/datasets", query).json()
      if body.get("code") != 0:
          raise Exception(body["message"])
      return [DataSet(self, entry) for entry in body['data']]
  def create_chat(self, name: str, avatar: str = "", dataset_ids: List[str] = None,
                  llm: Chat.LLM = None, prompt: Chat.Prompt = None) -> Chat:
      """
      Create a chat assistant via ``POST /chats``.

      :param name: chat name.
      :param avatar: avatar value, forwarded unchanged (default ``""``).
      :param dataset_ids: ids of the datasets to attach; ``None`` (the default)
          is treated as an empty list.
      :param llm: model settings; when ``None`` a ``Chat.LLM`` with the defaults
          shown below is built.
      :param prompt: prompt settings; when ``None`` a ``Chat.Prompt`` with the
          defaults shown below is built. If its ``opener`` or ``prompt`` text is
          ``None`` the built-in default text is assigned to it, which also
          modifies a caller-supplied object.
      :return: ``Chat`` built from the ``data`` field of the response.
      :raises Exception: with the response's ``message`` when ``code`` is not 0.
      """
      # The default is None rather than [] so no list object is shared across
      # calls; the copy keeps the payload independent of the caller's list.
      dataset_list = list(dataset_ids) if dataset_ids is not None else []

      if llm is None:
          llm = Chat.LLM(self, {"model_name": None,
                                "temperature": 0.1,
                                "top_p": 0.3,
                                "presence_penalty": 0.4,
                                "frequency_penalty": 0.7,
                                "max_tokens": 512, })
      if prompt is None:
          prompt = Chat.Prompt(self, {"similarity_threshold": 0.2,
                                      "keywords_similarity_weight": 0.7,
                                      "top_n": 8,
                                      "variables": [{
                                          "key": "knowledge",
                                          "optional": True
                                      }], "rerank_model": "",
                                      "empty_response": None,
                                      "opener": None,
                                      "show_quote": True,
                                      "prompt": None})
      # Fill in text defaults for both the freshly built and a caller-supplied prompt.
      if prompt.opener is None:
          prompt.opener = "Hi! I'm your assistant, what can I do for you?"
      if prompt.prompt is None:
          prompt.prompt = (
              "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
              "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
              "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
              "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
          )

      payload = {"name": name,
                 "avatar": avatar,
                 "dataset_ids": dataset_list,
                 "llm": llm.to_json(),
                 "prompt": prompt.to_json()}
      res = self.post("/chats", payload).json()
      if res.get("code") == 0:
          return Chat(self, res["data"])
      raise Exception(res["message"])
  def delete_chats(self, ids: List[str] = None, names: List[str] = None) -> None:
      """
      Delete chats via ``DELETE /chats`` with ``{"ids": ids, "names": names}``
      as the JSON body.

      Nothing is returned on success; the annotation previously said ``bool``
      although no value was ever returned.

      :param ids: chat ids to delete (sent as given, ``None`` by default).
      :param names: chat names to delete (sent as given, ``None`` by default).
      :raises Exception: with the response's ``message`` when ``code`` is not 0.
      """
      res = self.delete('/chats', {"ids": ids, "names": names}).json()
      if res.get("code") != 0:
          raise Exception(res["message"])
  def list_chats(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True,
                 id: str = None, name: str = None) -> List[Chat]:
      """
      List chats via ``GET /chats``.

      Every argument is sent as a query parameter under the same name.

      :return: one ``Chat`` per entry in the response's ``data`` list.
      :raises Exception: with the response's ``message`` when ``code`` is not 0.
      """
      query = {
          "page": page,
          "page_size": page_size,
          "orderby": orderby,
          "desc": desc,
          "id": id,
          "name": name,
      }
      body = self.get("/chats", query).json()
      if body.get("code") != 0:
          raise Exception(body["message"])
      return [Chat(self, entry) for entry in body['data']]
  def retrieve(self, dataset_ids, document_ids=None, question="", offset=1, limit=1024,
               similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024,
               rerank_id:str=None, keyword:bool=False, ):
      """
      Query ``POST /retrieval`` and return the matching chunks.

      ``dataset_ids`` is sent under the JSON key ``"datasets"`` and
      ``document_ids`` under ``"documents"`` (``None`` becomes an empty list);
      every other argument is sent under a key of the same name.

      :return: list of ``Chunk`` objects built from ``data.chunks`` in the
          response; empty when that field is missing or null.
      :raises Exception: with the response's ``message`` when ``code`` is not 0.
      """
      payload = {
          "offset": offset,
          "limit": limit,
          "similarity_threshold": similarity_threshold,
          "vector_similarity_weight": vector_similarity_weight,
          "top_k": top_k,
          "rerank_id": rerank_id,
          "keyword": keyword,
          "question": question,
          "datasets": dataset_ids,
          "documents": [] if document_ids is None else document_ids,
      }
      res = self.post('/retrieval', json=payload).json()
      if res.get("code") != 0:
          raise Exception(res.get("message"))
      # "chunks" may be absent or null; treat both as "no hits" instead of
      # failing with a TypeError while iterating None.
      hits = res["data"].get("chunks") or []
      return [Chunk(self, hit) for hit in hits]