You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

canvas.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import importlib
  17. import json
  18. import traceback
  19. from abc import ABC
  20. from copy import deepcopy
  21. from functools import partial
  22. import pandas as pd
  23. from graph.component import component_class
  24. from graph.component.base import ComponentBase
  25. from graph.settings import flow_logger, DEBUG
  26. class Canvas(ABC):
  27. """
  28. dsl = {
  29. "components": {
  30. "begin": {
  31. "obj":{
  32. "component_name": "Begin",
  33. "params": {},
  34. },
  35. "downstream": ["answer_0"],
  36. "upstream": [],
  37. },
  38. "answer_0": {
  39. "obj": {
  40. "component_name": "Answer",
  41. "params": {}
  42. },
  43. "downstream": ["retrieval_0"],
  44. "upstream": ["begin", "generate_0"],
  45. },
  46. "retrieval_0": {
  47. "obj": {
  48. "component_name": "Retrieval",
  49. "params": {}
  50. },
  51. "downstream": ["generate_0"],
  52. "upstream": ["answer_0"],
  53. },
  54. "generate_0": {
  55. "obj": {
  56. "component_name": "Generate",
  57. "params": {}
  58. },
  59. "downstream": ["answer_0"],
  60. "upstream": ["retrieval_0"],
  61. }
  62. },
  63. "history": [],
  64. "messages": [],
  65. "reference": [],
  66. "path": [["begin"]],
  67. "answer": []
  68. }
  69. """
  70. def __init__(self, dsl: str, tenant_id=None):
  71. self.path = []
  72. self.history = []
  73. self.messages = []
  74. self.answer = []
  75. self.components = {}
  76. self.dsl = json.loads(dsl) if dsl else {
  77. "components": {
  78. "begin": {
  79. "obj": {
  80. "component_name": "Begin",
  81. "params": {
  82. "prologue": "Hi there!"
  83. }
  84. },
  85. "downstream": [],
  86. "upstream": []
  87. }
  88. },
  89. "history": [],
  90. "messages": [],
  91. "reference": [],
  92. "path": [],
  93. "answer": []
  94. }
  95. self._tenant_id = tenant_id
  96. self._embed_id = ""
  97. self.load()
  98. def load(self):
  99. self.components = self.dsl["components"]
  100. cpn_nms = set([])
  101. for k, cpn in self.components.items():
  102. cpn_nms.add(cpn["obj"]["component_name"])
  103. assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
  104. assert "Answer" in cpn_nms, "There have to be an 'Answer' component."
  105. for k, cpn in self.components.items():
  106. cpn_nms.add(cpn["obj"]["component_name"])
  107. param = component_class(cpn["obj"]["component_name"] + "Param")()
  108. param.update(cpn["obj"]["params"])
  109. param.check()
  110. cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
  111. if cpn["obj"].component_name == "Categorize":
  112. for _, desc in param.category_description.items():
  113. if desc["to"] not in cpn["downstream"]:
  114. cpn["downstream"].append(desc["to"])
  115. self.path = self.dsl["path"]
  116. self.history = self.dsl["history"]
  117. self.messages = self.dsl["messages"]
  118. self.answer = self.dsl["answer"]
  119. self.reference = self.dsl["reference"]
  120. self._embed_id = self.dsl.get("embed_id", "")
  121. def __str__(self):
  122. self.dsl["path"] = self.path
  123. self.dsl["history"] = self.history
  124. self.dsl["messages"] = self.messages
  125. self.dsl["answer"] = self.answer
  126. self.dsl["reference"] = self.reference
  127. self.dsl["embed_id"] = self._embed_id
  128. dsl = {
  129. "components": {}
  130. }
  131. for k in self.dsl.keys():
  132. if k in ["components"]:continue
  133. dsl[k] = deepcopy(self.dsl[k])
  134. for k, cpn in self.components.items():
  135. if k not in dsl["components"]:
  136. dsl["components"][k] = {}
  137. for c in cpn.keys():
  138. if c == "obj":
  139. dsl["components"][k][c] = json.loads(str(cpn["obj"]))
  140. continue
  141. dsl["components"][k][c] = deepcopy(cpn[c])
  142. return json.dumps(dsl, ensure_ascii=False)
  143. def reset(self):
  144. self.path = []
  145. self.history = []
  146. self.messages = []
  147. self.answer = []
  148. self.reference = []
  149. for k, cpn in self.components.items():
  150. self.components[k]["obj"].reset()
  151. self._embed_id = ""
  152. def run(self, **kwargs):
  153. ans = ""
  154. if self.answer:
  155. cpn_id = self.answer[0]
  156. self.answer.pop(0)
  157. try:
  158. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  159. except Exception as e:
  160. ans = ComponentBase.be_output(str(e))
  161. self.path[-1].append(cpn_id)
  162. if kwargs.get("stream"):
  163. assert isinstance(ans, partial)
  164. return ans
  165. self.history.append(("assistant", ans.to_dict("records")))
  166. return ans
  167. if not self.path:
  168. self.components["begin"]["obj"].run(self.history, **kwargs)
  169. self.path.append(["begin"])
  170. self.path.append([])
  171. ran = -1
  172. def prepare2run(cpns):
  173. nonlocal ran, ans
  174. for c in cpns:
  175. cpn = self.components[c]["obj"]
  176. if cpn.component_name == "Answer":
  177. self.answer.append(c)
  178. else:
  179. if DEBUG: print("RUN: ", c)
  180. if cpn.component_name == "Generate":
  181. cpids = cpn.get_dependent_components()
  182. if any([c not in self.path[-1] for c in cpids]):
  183. continue
  184. ans = cpn.run(self.history, **kwargs)
  185. self.path[-1].append(c)
  186. ran += 1
  187. prepare2run(self.components[self.path[-2][-1]]["downstream"])
  188. while 0 <= ran < len(self.path[-1]):
  189. if DEBUG: print(ran, self.path)
  190. cpn_id = self.path[-1][ran]
  191. cpn = self.get_component(cpn_id)
  192. if not cpn["downstream"]: break
  193. loop = self._find_loop()
  194. if loop: raise OverflowError(f"Too much loops: {loop}")
  195. if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
  196. switch_out = cpn["obj"].output()[1].iloc[0, 0]
  197. assert switch_out in self.components, \
  198. "{}'s output: {} not valid.".format(cpn_id, switch_out)
  199. try:
  200. prepare2run([switch_out])
  201. except Exception as e:
  202. for p in [c for p in self.path for c in p][::-1]:
  203. if p.lower().find("answer") >= 0:
  204. self.get_component(p)["obj"].set_exception(e)
  205. prepare2run([p])
  206. break
  207. traceback.print_exc()
  208. break
  209. continue
  210. try:
  211. prepare2run(cpn["downstream"])
  212. except Exception as e:
  213. for p in [c for p in self.path for c in p][::-1]:
  214. if p.lower().find("answer") >= 0:
  215. self.get_component(p)["obj"].set_exception(e)
  216. prepare2run([p])
  217. break
  218. traceback.print_exc()
  219. break
  220. if self.answer:
  221. cpn_id = self.answer[0]
  222. self.answer.pop(0)
  223. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  224. self.path[-1].append(cpn_id)
  225. if kwargs.get("stream"):
  226. assert isinstance(ans, partial)
  227. return ans
  228. self.history.append(("assistant", ans.to_dict("records")))
  229. return ans
  230. def get_component(self, cpn_id):
  231. return self.components[cpn_id]
  232. def get_tenant_id(self):
  233. return self._tenant_id
  234. def get_history(self, window_size):
  235. convs = []
  236. for role, obj in self.history[window_size * -2:]:
  237. convs.append({"role": role, "content": (obj if role == "user" else
  238. '\n'.join(pd.DataFrame(obj)['content']))})
  239. return convs
  240. def add_user_input(self, question):
  241. self.history.append(("user", question))
  242. def set_embedding_model(self, embed_id):
  243. self._embed_id = embed_id
  244. def get_embedding_model(self):
  245. return self._embed_id
  246. def _find_loop(self, max_loops=2):
  247. path = self.path[-1][::-1]
  248. if len(path) < 2: return False
  249. for i in range(len(path)):
  250. if path[i].lower().find("answer") >= 0:
  251. path = path[:i]
  252. break
  253. if len(path) < 2: return False
  254. for l in range(2, len(path) // 2):
  255. pat = ",".join(path[0:l])
  256. path_str = ",".join(path)
  257. if len(pat) >= len(path_str): return False
  258. loop = max_loops
  259. while path_str.find(pat) == 0 and loop >= 0:
  260. loop -= 1
  261. if len(pat)+1 >= len(path_str):
  262. return False
  263. path_str = path_str[len(pat)+1:]
  264. if loop < 0:
  265. pat = " => ".join([p.split(":")[0] for p in path[0:l]])
  266. return pat + " => " + pat
  267. return False