選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

canvas.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import importlib
  17. import json
  18. import traceback
  19. from abc import ABC
  20. from copy import deepcopy
  21. from functools import partial
  22. import pandas as pd
  23. from agent.component import component_class
  24. from agent.component.base import ComponentBase
  25. from agent.settings import flow_logger, DEBUG
  26. class Canvas(ABC):
  27. """
  28. dsl = {
  29. "components": {
  30. "begin": {
  31. "obj":{
  32. "component_name": "Begin",
  33. "params": {},
  34. },
  35. "downstream": ["answer_0"],
  36. "upstream": [],
  37. },
  38. "answer_0": {
  39. "obj": {
  40. "component_name": "Answer",
  41. "params": {}
  42. },
  43. "downstream": ["retrieval_0"],
  44. "upstream": ["begin", "generate_0"],
  45. },
  46. "retrieval_0": {
  47. "obj": {
  48. "component_name": "Retrieval",
  49. "params": {}
  50. },
  51. "downstream": ["generate_0"],
  52. "upstream": ["answer_0"],
  53. },
  54. "generate_0": {
  55. "obj": {
  56. "component_name": "Generate",
  57. "params": {}
  58. },
  59. "downstream": ["answer_0"],
  60. "upstream": ["retrieval_0"],
  61. }
  62. },
  63. "history": [],
  64. "messages": [],
  65. "reference": [],
  66. "path": [["begin"]],
  67. "answer": []
  68. }
  69. """
  70. def __init__(self, dsl: str, tenant_id=None):
  71. self.path = []
  72. self.history = []
  73. self.messages = []
  74. self.answer = []
  75. self.components = {}
  76. self.dsl = json.loads(dsl) if dsl else {
  77. "components": {
  78. "begin": {
  79. "obj": {
  80. "component_name": "Begin",
  81. "params": {
  82. "prologue": "Hi there!"
  83. }
  84. },
  85. "downstream": [],
  86. "upstream": []
  87. }
  88. },
  89. "history": [],
  90. "messages": [],
  91. "reference": [],
  92. "path": [],
  93. "answer": []
  94. }
  95. self._tenant_id = tenant_id
  96. self._embed_id = ""
  97. self.load()
  98. def load(self):
  99. self.components = self.dsl["components"]
  100. cpn_nms = set([])
  101. for k, cpn in self.components.items():
  102. cpn_nms.add(cpn["obj"]["component_name"])
  103. assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
  104. assert "Answer" in cpn_nms, "There have to be an 'Answer' component."
  105. for k, cpn in self.components.items():
  106. cpn_nms.add(cpn["obj"]["component_name"])
  107. param = component_class(cpn["obj"]["component_name"] + "Param")()
  108. param.update(cpn["obj"]["params"])
  109. param.check()
  110. cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
  111. if cpn["obj"].component_name == "Categorize":
  112. for _, desc in param.category_description.items():
  113. if desc["to"] not in cpn["downstream"]:
  114. cpn["downstream"].append(desc["to"])
  115. self.path = self.dsl["path"]
  116. self.history = self.dsl["history"]
  117. self.messages = self.dsl["messages"]
  118. self.answer = self.dsl["answer"]
  119. self.reference = self.dsl["reference"]
  120. self._embed_id = self.dsl.get("embed_id", "")
  121. def __str__(self):
  122. self.dsl["path"] = self.path
  123. self.dsl["history"] = self.history
  124. self.dsl["messages"] = self.messages
  125. self.dsl["answer"] = self.answer
  126. self.dsl["reference"] = self.reference
  127. self.dsl["embed_id"] = self._embed_id
  128. dsl = {
  129. "components": {}
  130. }
  131. for k in self.dsl.keys():
  132. if k in ["components"]:continue
  133. dsl[k] = deepcopy(self.dsl[k])
  134. for k, cpn in self.components.items():
  135. if k not in dsl["components"]:
  136. dsl["components"][k] = {}
  137. for c in cpn.keys():
  138. if c == "obj":
  139. dsl["components"][k][c] = json.loads(str(cpn["obj"]))
  140. continue
  141. dsl["components"][k][c] = deepcopy(cpn[c])
  142. return json.dumps(dsl, ensure_ascii=False)
  143. def reset(self):
  144. self.path = []
  145. self.history = []
  146. self.messages = []
  147. self.answer = []
  148. self.reference = []
  149. for k, cpn in self.components.items():
  150. self.components[k]["obj"].reset()
  151. self._embed_id = ""
  # Advance the canvas by one call. Visible behaviour: a queued Answer
  # component (self.answer) is run first if there is one; otherwise the
  # components downstream of the last executed component are run, and an
  # Answer queued during that walk is run at the end.
  # NOTE(review): indentation was lost in this copy of the file, so the
  # nesting of several statements below cannot be confirmed from the text
  # alone -- check against the upstream file before changing any logic.
  152. def run(self, **kwargs):
  153. ans = ""
  # A pending Answer component is popped and run; an exception raised by it
  # is turned into an output via ComponentBase.be_output(str(e)).
  154. if self.answer:
  155. cpn_id = self.answer[0]
  156. self.answer.pop(0)
  157. try:
  158. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  159. except Exception as e:
  160. ans = ComponentBase.be_output(str(e))
  161. self.path[-1].append(cpn_id)
  # With a truthy "stream" kwarg the result is asserted to be a
  # functools.partial and returned without touching the history.
  162. if kwargs.get("stream"):
  163. assert isinstance(ans, partial)
  164. return ans
  # Otherwise the output is recorded as an "assistant" history entry using
  # ans.to_dict("records").
  165. self.history.append(("assistant", ans.to_dict("records")))
  166. return ans
  # Empty path: run the "begin" component and record it as the first step.
  167. if not self.path:
  168. self.components["begin"]["obj"].run(self.history, **kwargs)
  169. self.path.append(["begin"])
  # Open a new (empty) step in the path for this call; `ran` is the index
  # into that step used by the while loop below.
  170. self.path.append([])
  171. ran = -1
  # Helper: for each candidate id, skip it when it equals the last id of the
  # current step; queue Answer components in self.answer instead of running
  # them; skip a Generate component while any id returned by
  # get_dependent_components() is missing from the current step; otherwise
  # run the component and append its id to the current step.
  # NOTE(review): whether `ran += 1` executes once per call or once per
  # executed component depends on its (lost) indentation -- confirm.
  172. def prepare2run(cpns):
  173. nonlocal ran, ans
  174. for c in cpns:
  175. if self.path[-1] and c == self.path[-1][-1]: continue
  176. cpn = self.components[c]["obj"]
  177. if cpn.component_name == "Answer":
  178. self.answer.append(c)
  179. else:
  180. if DEBUG: print("RUN: ", c)
  181. if cpn.component_name == "Generate":
  182. cpids = cpn.get_dependent_components()
  183. if any([c not in self.path[-1] for c in cpids]):
  184. continue
  185. ans = cpn.run(self.history, **kwargs)
  186. self.path[-1].append(c)
  187. ran += 1
  # Start from the downstream ids of the last component of the previous step.
  188. prepare2run(self.components[self.path[-2][-1]]["downstream"])
  # Walk the ids recorded in the current step, using `ran` as the index.
  189. while 0 <= ran < len(self.path[-1]):
  190. if DEBUG: print(ran, self.path)
  191. cpn_id = self.path[-1][ran]
  192. cpn = self.get_component(cpn_id)
  193. if not cpn["downstream"]: break
  # Abort when _find_loop() reports a repeated pattern in the current step.
  194. loop = self._find_loop()
  195. if loop: raise OverflowError(f"Too much loops: {loop}")
  # Switch/Categorize/Relevant components name the next component id in the
  # first cell (iloc[0, 0]) of the second element of their output().
  196. if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
  197. switch_out = cpn["obj"].output()[1].iloc[0, 0]
  198. assert switch_out in self.components, \
  199. "{}'s output: {} not valid.".format(cpn_id, switch_out)
  # On failure: scan every id in self.path, newest first, for one whose id
  # contains "answer", hand it the exception via set_exception() and
  # schedule it through prepare2run.
  # NOTE(review): nesting of traceback.print_exc()/break relative to the
  # `for` loop is not recoverable from this copy -- confirm.
  200. try:
  201. prepare2run([switch_out])
  202. except Exception as e:
  203. for p in [c for p in self.path for c in p][::-1]:
  204. if p.lower().find("answer") >= 0:
  205. self.get_component(p)["obj"].set_exception(e)
  206. prepare2run([p])
  207. break
  208. traceback.print_exc()
  209. break
  210. continue
  # Any other component: schedule all of its downstream ids, with the same
  # exception handling as above.
  211. try:
  212. prepare2run(cpn["downstream"])
  213. except Exception as e:
  214. for p in [c for p in self.path for c in p][::-1]:
  215. if p.lower().find("answer") >= 0:
  216. self.get_component(p)["obj"].set_exception(e)
  217. prepare2run([p])
  218. break
  219. traceback.print_exc()
  220. break
  # Run an Answer component queued during the walk; unlike the first branch
  # of this method, an exception raised here is not caught.
  221. if self.answer:
  222. cpn_id = self.answer[0]
  223. self.answer.pop(0)
  224. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  225. self.path[-1].append(cpn_id)
  226. if kwargs.get("stream"):
  227. assert isinstance(ans, partial)
  228. return ans
  # NOTE(review): it is unclear from this copy whether the two lines below
  # belong to the `if self.answer:` block or to the function body. If they
  # are at function level and nothing ran, `ans` is still "" and
  # ans.to_dict would fail -- confirm.
  229. self.history.append(("assistant", ans.to_dict("records")))
  230. return ans
  231. def get_component(self, cpn_id):
  232. return self.components[cpn_id]
  233. def get_tenant_id(self):
  234. return self._tenant_id
  235. def get_history(self, window_size):
  236. convs = []
  237. for role, obj in self.history[window_size * -2:]:
  238. convs.append({"role": role, "content": (obj if role == "user" else
  239. '\n'.join(pd.DataFrame(obj)['content']))})
  240. return convs
  241. def add_user_input(self, question):
  242. self.history.append(("user", question))
  243. def set_embedding_model(self, embed_id):
  244. self._embed_id = embed_id
  245. def get_embedding_model(self):
  246. return self._embed_id
  247. def _find_loop(self, max_loops=2):
  248. path = self.path[-1][::-1]
  249. if len(path) < 2: return False
  250. for i in range(len(path)):
  251. if path[i].lower().find("answer") >= 0:
  252. path = path[:i]
  253. break
  254. if len(path) < 2: return False
  255. for l in range(2, len(path) // 2):
  256. pat = ",".join(path[0:l])
  257. path_str = ",".join(path)
  258. if len(pat) >= len(path_str): return False
  259. loop = max_loops
  260. while path_str.find(pat) == 0 and loop >= 0:
  261. loop -= 1
  262. if len(pat)+1 >= len(path_str):
  263. return False
  264. path_str = path_str[len(pat)+1:]
  265. if loop < 0:
  266. pat = " => ".join([p.split(":")[0] for p in path[0:l]])
  267. return pat + " => " + pat
  268. return False