Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

canvas.py 9.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import importlib
  17. import json
  18. import traceback
  19. from abc import ABC
  20. from copy import deepcopy
  21. from functools import partial
  22. import pandas as pd
  23. from graph.component import component_class
  24. from graph.component.base import ComponentBase
  25. from graph.settings import flow_logger, DEBUG
  26. class Canvas(ABC):
  27. """
  28. dsl = {
  29. "components": {
  30. "begin": {
  31. "obj":{
  32. "component_name": "Begin",
  33. "params": {},
  34. },
  35. "downstream": ["answer_0"],
  36. "upstream": [],
  37. },
  38. "answer_0": {
  39. "obj": {
  40. "component_name": "Answer",
  41. "params": {}
  42. },
  43. "downstream": ["retrieval_0"],
  44. "upstream": ["begin", "generate_0"],
  45. },
  46. "retrieval_0": {
  47. "obj": {
  48. "component_name": "Retrieval",
  49. "params": {}
  50. },
  51. "downstream": ["generate_0"],
  52. "upstream": ["answer_0"],
  53. },
  54. "generate_0": {
  55. "obj": {
  56. "component_name": "Generate",
  57. "params": {}
  58. },
  59. "downstream": ["answer_0"],
  60. "upstream": ["retrieval_0"],
  61. }
  62. },
  63. "history": [],
  64. "messages": [],
  65. "reference": [],
  66. "path": [["begin"]],
  67. "answer": []
  68. }
  69. """
  70. def __init__(self, dsl: str, tenant_id=None):
  71. self.path = []
  72. self.history = []
  73. self.messages = []
  74. self.answer = []
  75. self.components = {}
  76. self.dsl = json.loads(dsl) if dsl else {
  77. "components": {
  78. "begin": {
  79. "obj": {
  80. "component_name": "Begin",
  81. "params": {
  82. "prologue": "Hi there!"
  83. }
  84. },
  85. "downstream": [],
  86. "upstream": []
  87. }
  88. },
  89. "history": [],
  90. "messages": [],
  91. "reference": [],
  92. "path": [],
  93. "answer": []
  94. }
  95. self._tenant_id = tenant_id
  96. self._embed_id = ""
  97. self.load()
  98. def load(self):
  99. self.components = self.dsl["components"]
  100. cpn_nms = set([])
  101. for k, cpn in self.components.items():
  102. cpn_nms.add(cpn["obj"]["component_name"])
  103. assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
  104. assert "Answer" in cpn_nms, "There have to be an 'Answer' component."
  105. for k, cpn in self.components.items():
  106. cpn_nms.add(cpn["obj"]["component_name"])
  107. param = component_class(cpn["obj"]["component_name"] + "Param")()
  108. param.update(cpn["obj"]["params"])
  109. param.check()
  110. cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
  111. if cpn["obj"].component_name == "Categorize":
  112. for _, desc in param.category_description.items():
  113. if desc["to"] not in cpn["downstream"]:
  114. cpn["downstream"].append(desc["to"])
  115. self.path = self.dsl["path"]
  116. self.history = self.dsl["history"]
  117. self.messages = self.dsl["messages"]
  118. self.answer = self.dsl["answer"]
  119. self.reference = self.dsl["reference"]
  120. self._embed_id = self.dsl.get("embed_id", "")
  121. def __str__(self):
  122. self.dsl["path"] = self.path
  123. self.dsl["history"] = self.history
  124. self.dsl["messages"] = self.messages
  125. self.dsl["answer"] = self.answer
  126. self.dsl["reference"] = self.reference
  127. self.dsl["embed_id"] = self._embed_id
  128. dsl = {
  129. "components": {}
  130. }
  131. for k in self.dsl.keys():
  132. if k in ["components"]:continue
  133. dsl[k] = deepcopy(self.dsl[k])
  134. for k, cpn in self.components.items():
  135. if k not in dsl["components"]:
  136. dsl["components"][k] = {}
  137. for c in cpn.keys():
  138. if c == "obj":
  139. dsl["components"][k][c] = json.loads(str(cpn["obj"]))
  140. continue
  141. dsl["components"][k][c] = deepcopy(cpn[c])
  142. return json.dumps(dsl, ensure_ascii=False)
  143. def reset(self):
  144. self.path = []
  145. self.history = []
  146. self.messages = []
  147. self.answer = []
  148. self.reference = []
  149. for k, cpn in self.components.items():
  150. self.components[k]["obj"].reset()
  151. self._embed_id = ""
  152. def run(self, **kwargs):
  153. ans = ""
  154. if self.answer:
  155. cpn_id = self.answer[0]
  156. self.answer.pop(0)
  157. try:
  158. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  159. except Exception as e:
  160. ans = ComponentBase.be_output(str(e))
  161. self.path[-1].append(cpn_id)
  162. if kwargs.get("stream"):
  163. assert isinstance(ans, partial)
  164. return ans
  165. self.history.append(("assistant", ans.to_dict("records")))
  166. return ans
  167. if not self.path:
  168. self.components["begin"]["obj"].run(self.history, **kwargs)
  169. self.path.append(["begin"])
  170. self.path.append([])
  171. ran = -1
  172. def prepare2run(cpns):
  173. nonlocal ran, ans
  174. for c in cpns:
  175. cpn = self.components[c]["obj"]
  176. if cpn.component_name == "Answer":
  177. self.answer.append(c)
  178. else:
  179. if DEBUG: print("RUN: ", c)
  180. ans = cpn.run(self.history, **kwargs)
  181. self.path[-1].append(c)
  182. ran += 1
  183. prepare2run(self.components[self.path[-2][-1]]["downstream"])
  184. while 0 <= ran < len(self.path[-1]):
  185. if DEBUG: print(ran, self.path)
  186. cpn_id = self.path[-1][ran]
  187. cpn = self.get_component(cpn_id)
  188. if not cpn["downstream"]: break
  189. if self._find_loop(): raise OverflowError("Too much loops!")
  190. if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
  191. switch_out = cpn["obj"].output()[1].iloc[0, 0]
  192. assert switch_out in self.components, \
  193. "{}'s output: {} not valid.".format(cpn_id, switch_out)
  194. try:
  195. prepare2run([switch_out])
  196. except Exception as e:
  197. for p in [c for p in self.path for c in p][::-1]:
  198. if p.lower().find("answer") >= 0:
  199. self.get_component(p)["obj"].set_exception(e)
  200. prepare2run([p])
  201. break
  202. traceback.print_exc()
  203. continue
  204. try:
  205. prepare2run(cpn["downstream"])
  206. except Exception as e:
  207. for p in [c for p in self.path for c in p][::-1]:
  208. if p.lower().find("answer") >= 0:
  209. self.get_component(p)["obj"].set_exception(e)
  210. prepare2run([p])
  211. break
  212. traceback.print_exc()
  213. if self.answer:
  214. cpn_id = self.answer[0]
  215. self.answer.pop(0)
  216. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  217. self.path[-1].append(cpn_id)
  218. if kwargs.get("stream"):
  219. assert isinstance(ans, partial)
  220. return ans
  221. self.history.append(("assistant", ans.to_dict("records")))
  222. return ans
  223. def get_component(self, cpn_id):
  224. return self.components[cpn_id]
  225. def get_tenant_id(self):
  226. return self._tenant_id
  227. def get_history(self, window_size):
  228. convs = []
  229. for role, obj in self.history[window_size * -2:]:
  230. convs.append({"role": role, "content": (obj if role == "user" else
  231. '\n'.join(pd.DataFrame(obj)['content']))})
  232. return convs
  233. def add_user_input(self, question):
  234. self.history.append(("user", question))
  235. def set_embedding_model(self, embed_id):
  236. self._embed_id = embed_id
  237. def get_embedding_model(self):
  238. return self._embed_id
  239. def _find_loop(self, max_loops=2):
  240. path = self.path[-1][::-1]
  241. if len(path) < 2: return False
  242. for i in range(len(path)):
  243. if path[i].lower().find("answer") >= 0:
  244. path = path[:i]
  245. break
  246. if len(path) < 2: return False
  247. for l in range(1, len(path) // 2):
  248. pat = ",".join(path[0:l])
  249. path_str = ",".join(path)
  250. if len(pat) >= len(path_str): return False
  251. path_str = path_str[len(pat):]
  252. loop = max_loops
  253. while path_str.find(pat) >= 0 and loop >= 0:
  254. loop -= 1
  255. path_str = path_str[len(pat):]
  256. if loop < 0: return True
  257. return False