Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

ragflow.py 10.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import requests

from .modules.agent import Agent
from .modules.chat import Chat
from .modules.chunk import Chunk
from .modules.dataset import DataSet
  21. class RAGFlow:
  22. def __init__(self, api_key, base_url, version="v1"):
  23. """
  24. api_url: http://<host_address>/api/v1
  25. """
  26. self.user_key = api_key
  27. self.api_url = f"{base_url}/api/{version}"
  28. self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
  29. def post(self, path, json=None, stream=False, files=None):
  30. res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
  31. return res
  32. def get(self, path, params=None, json=None):
  33. res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
  34. return res
  35. def delete(self, path, json):
  36. res = requests.delete(url=self.api_url + path, json=json, headers=self.authorization_header)
  37. return res
  38. def put(self, path, json):
  39. res = requests.put(url=self.api_url + path, json=json, headers=self.authorization_header)
  40. return res
  41. def create_dataset(
  42. self,
  43. name: str,
  44. avatar: Optional[str] = None,
  45. description: Optional[str] = None,
  46. embedding_model: Optional[str] = None,
  47. permission: str = "me",
  48. chunk_method: str = "naive",
  49. parser_config: Optional[DataSet.ParserConfig] = None,
  50. ) -> DataSet:
  51. payload = {
  52. "name": name,
  53. "avatar": avatar,
  54. "description": description,
  55. "embedding_model": embedding_model,
  56. "permission": permission,
  57. "chunk_method": chunk_method,
  58. }
  59. if parser_config is not None:
  60. payload["parser_config"] = parser_config.to_json()
  61. res = self.post("/datasets", payload)
  62. res = res.json()
  63. if res.get("code") == 0:
  64. return DataSet(self, res["data"])
  65. raise Exception(res["message"])
  66. def delete_datasets(self, ids: list[str] | None = None):
  67. res = self.delete("/datasets", {"ids": ids})
  68. res = res.json()
  69. if res.get("code") != 0:
  70. raise Exception(res["message"])
  71. def get_dataset(self, name: str):
  72. _list = self.list_datasets(name=name)
  73. if len(_list) > 0:
  74. return _list[0]
  75. raise Exception("Dataset %s not found" % name)
  76. def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]:
  77. res = self.get(
  78. "/datasets",
  79. {
  80. "page": page,
  81. "page_size": page_size,
  82. "orderby": orderby,
  83. "desc": desc,
  84. "id": id,
  85. "name": name,
  86. },
  87. )
  88. res = res.json()
  89. result_list = []
  90. if res.get("code") == 0:
  91. for data in res["data"]:
  92. result_list.append(DataSet(self, data))
  93. return result_list
  94. raise Exception(res["message"])
  95. def create_chat(self, name: str, avatar: str = "", dataset_ids=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat:
  96. if dataset_ids is None:
  97. dataset_ids = []
  98. dataset_list = []
  99. for id in dataset_ids:
  100. dataset_list.append(id)
  101. if llm is None:
  102. llm = Chat.LLM(
  103. self,
  104. {
  105. "model_name": None,
  106. "temperature": 0.1,
  107. "top_p": 0.3,
  108. "presence_penalty": 0.4,
  109. "frequency_penalty": 0.7,
  110. "max_tokens": 512,
  111. },
  112. )
  113. if prompt is None:
  114. prompt = Chat.Prompt(
  115. self,
  116. {
  117. "similarity_threshold": 0.2,
  118. "keywords_similarity_weight": 0.7,
  119. "top_n": 8,
  120. "top_k": 1024,
  121. "variables": [{"key": "knowledge", "optional": True}],
  122. "rerank_model": "",
  123. "empty_response": None,
  124. "opener": None,
  125. "show_quote": True,
  126. "prompt": None,
  127. },
  128. )
  129. if prompt.opener is None:
  130. prompt.opener = "Hi! I'm your assistant. What can I do for you?"
  131. if prompt.prompt is None:
  132. prompt.prompt = (
  133. "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. "
  134. "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, "
  135. "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' "
  136. "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
  137. )
  138. temp_dict = {"name": name, "avatar": avatar, "dataset_ids": dataset_list if dataset_list else [], "llm": llm.to_json(), "prompt": prompt.to_json()}
  139. res = self.post("/chats", temp_dict)
  140. res = res.json()
  141. if res.get("code") == 0:
  142. return Chat(self, res["data"])
  143. raise Exception(res["message"])
  144. def delete_chats(self, ids: list[str] | None = None):
  145. res = self.delete("/chats", {"ids": ids})
  146. res = res.json()
  147. if res.get("code") != 0:
  148. raise Exception(res["message"])
  149. def list_chats(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[Chat]:
  150. res = self.get(
  151. "/chats",
  152. {
  153. "page": page,
  154. "page_size": page_size,
  155. "orderby": orderby,
  156. "desc": desc,
  157. "id": id,
  158. "name": name,
  159. },
  160. )
  161. res = res.json()
  162. result_list = []
  163. if res.get("code") == 0:
  164. for data in res["data"]:
  165. result_list.append(Chat(self, data))
  166. return result_list
  167. raise Exception(res["message"])
  168. def retrieve(
  169. self,
  170. dataset_ids,
  171. document_ids=None,
  172. question="",
  173. page=1,
  174. page_size=30,
  175. similarity_threshold=0.2,
  176. vector_similarity_weight=0.3,
  177. top_k=1024,
  178. rerank_id: str | None = None,
  179. keyword: bool = False,
  180. cross_languages: list[str]|None = None
  181. ):
  182. if document_ids is None:
  183. document_ids = []
  184. data_json = {
  185. "page": page,
  186. "page_size": page_size,
  187. "similarity_threshold": similarity_threshold,
  188. "vector_similarity_weight": vector_similarity_weight,
  189. "top_k": top_k,
  190. "rerank_id": rerank_id,
  191. "keyword": keyword,
  192. "question": question,
  193. "dataset_ids": dataset_ids,
  194. "document_ids": document_ids,
  195. "cross_languages": cross_languages
  196. }
  197. # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
  198. res = self.post("/retrieval", json=data_json)
  199. res = res.json()
  200. if res.get("code") == 0:
  201. chunks = []
  202. for chunk_data in res["data"].get("chunks"):
  203. chunk = Chunk(self, chunk_data)
  204. chunks.append(chunk)
  205. return chunks
  206. raise Exception(res.get("message"))
  207. def list_agents(self, page: int = 1, page_size: int = 30, orderby: str = "update_time", desc: bool = True, id: str | None = None, title: str | None = None) -> list[Agent]:
  208. res = self.get(
  209. "/agents",
  210. {
  211. "page": page,
  212. "page_size": page_size,
  213. "orderby": orderby,
  214. "desc": desc,
  215. "id": id,
  216. "title": title,
  217. },
  218. )
  219. res = res.json()
  220. result_list = []
  221. if res.get("code") == 0:
  222. for data in res["data"]:
  223. result_list.append(Agent(self, data))
  224. return result_list
  225. raise Exception(res["message"])
  226. def create_agent(self, title: str, dsl: dict, description: str | None = None) -> None:
  227. req = {"title": title, "dsl": dsl}
  228. if description is not None:
  229. req["description"] = description
  230. res = self.post("/agents", req)
  231. res = res.json()
  232. if res.get("code") != 0:
  233. raise Exception(res["message"])
  234. def update_agent(self, agent_id: str, title: str | None = None, description: str | None = None, dsl: dict | None = None) -> None:
  235. req = {}
  236. if title is not None:
  237. req["title"] = title
  238. if description is not None:
  239. req["description"] = description
  240. if dsl is not None:
  241. req["dsl"] = dsl
  242. res = self.put(f"/agents/{agent_id}", req)
  243. res = res.json()
  244. if res.get("code") != 0:
  245. raise Exception(res["message"])
  246. def delete_agent(self, agent_id: str) -> None:
  247. res = self.delete(f"/agents/{agent_id}", {})
  248. res = res.json()
  249. if res.get("code") != 0:
  250. raise Exception(res["message"])