Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder mit einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

canvas.py 8.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import importlib
  17. import json
  18. import traceback
  19. from abc import ABC
  20. from copy import deepcopy
  21. from functools import partial
  22. import pandas as pd
  23. from graph.component import component_class
  24. from graph.component.base import ComponentBase
  25. from graph.settings import flow_logger, DEBUG
  26. class Canvas(ABC):
  27. """
  28. dsl = {
  29. "components": {
  30. "begin": {
  31. "obj":{
  32. "component_name": "Begin",
  33. "params": {},
  34. },
  35. "downstream": ["answer_0"],
  36. "upstream": [],
  37. },
  38. "answer_0": {
  39. "obj": {
  40. "component_name": "Answer",
  41. "params": {}
  42. },
  43. "downstream": ["retrieval_0"],
  44. "upstream": ["begin", "generate_0"],
  45. },
  46. "retrieval_0": {
  47. "obj": {
  48. "component_name": "Retrieval",
  49. "params": {}
  50. },
  51. "downstream": ["generate_0"],
  52. "upstream": ["answer_0"],
  53. },
  54. "generate_0": {
  55. "obj": {
  56. "component_name": "Generate",
  57. "params": {}
  58. },
  59. "downstream": ["answer_0"],
  60. "upstream": ["retrieval_0"],
  61. }
  62. },
  63. "history": [],
  64. "messages": [],
  65. "reference": [],
  66. "path": [["begin"]],
  67. "answer": []
  68. }
  69. """
  70. def __init__(self, dsl: str, tenant_id=None):
  71. self.path = []
  72. self.history = []
  73. self.messages = []
  74. self.answer = []
  75. self.components = {}
  76. self.dsl = json.loads(dsl) if dsl else {
  77. "components": {
  78. "begin": {
  79. "obj": {
  80. "component_name": "Begin",
  81. "params": {
  82. "prologue": "Hi there!"
  83. }
  84. },
  85. "downstream": [],
  86. "upstream": []
  87. }
  88. },
  89. "history": [],
  90. "messages": [],
  91. "reference": [],
  92. "path": [],
  93. "answer": []
  94. }
  95. self._tenant_id = tenant_id
  96. self._embed_id = ""
  97. self.load()
  98. def load(self):
  99. assert self.dsl.get("components", {}).get("begin"), "There have to be a 'Begin' component."
  100. self.components = self.dsl["components"]
  101. for k, cpn in self.components.items():
  102. param = component_class(cpn["obj"]["component_name"] + "Param")()
  103. param.update(cpn["obj"]["params"])
  104. param.check()
  105. cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
  106. if cpn["obj"].component_name == "Categorize":
  107. for _,desc in param.category_description.items():
  108. if desc["to"] not in cpn["downstream"]:
  109. cpn["downstream"].append(desc["to"])
  110. self.path = self.dsl["path"]
  111. self.history = self.dsl["history"]
  112. self.messages = self.dsl["messages"]
  113. self.answer = self.dsl["answer"]
  114. self.reference = self.dsl["reference"]
  115. self._embed_id = self.dsl.get("embed_id", "")
  116. def __str__(self):
  117. self.dsl["path"] = self.path
  118. self.dsl["history"] = self.history
  119. self.dsl["messages"] = self.messages
  120. self.dsl["answer"] = self.answer
  121. self.dsl["reference"] = self.reference
  122. self.dsl["embed_id"] = self._embed_id
  123. dsl = deepcopy(self.dsl)
  124. for k, cpn in self.components.items():
  125. dsl["components"][k]["obj"] = json.loads(str(cpn["obj"]))
  126. return json.dumps(dsl, ensure_ascii=False)
  127. def reset(self):
  128. self.path = []
  129. self.history = []
  130. self.messages = []
  131. self.answer = []
  132. self.reference = []
  133. self.components = {}
  134. self._embed_id = ""
  135. def run(self, **kwargs):
  136. ans = ""
  137. if self.answer:
  138. cpn_id = self.answer[0]
  139. self.answer.pop(0)
  140. try:
  141. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  142. except Exception as e:
  143. ans = ComponentBase.be_output(str(e))
  144. self.path[-1].append(cpn_id)
  145. self.history.append(("assistant", ans.to_dict("records")))
  146. return ans
  147. if not self.path:
  148. self.components["begin"]["obj"].run(self.history, **kwargs)
  149. self.path.append(["begin"])
  150. self.path.append([])
  151. ran = -1
  152. def prepare2run(cpns):
  153. nonlocal ran, ans
  154. for c in cpns:
  155. cpn = self.components[c]["obj"]
  156. if cpn.component_name == "Answer":
  157. self.answer.append(c)
  158. else:
  159. if DEBUG: print("RUN: ", c)
  160. ans = cpn.run(self.history, **kwargs)
  161. self.path[-1].append(c)
  162. ran += 1
  163. prepare2run(self.components[self.path[-2][-1]]["downstream"])
  164. while ran < len(self.path[-1]):
  165. if DEBUG: print(ran, self.path)
  166. cpn_id = self.path[-1][ran]
  167. cpn = self.get_component(cpn_id)
  168. if not cpn["downstream"]: break
  169. if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
  170. switch_out = cpn["obj"].output()[1].iloc[0, 0]
  171. assert switch_out in self.components, \
  172. "{}'s output: {} not valid.".format(cpn_id, switch_out)
  173. try:
  174. prepare2run([switch_out])
  175. except Exception as e:
  176. for p in [c for p in self.path for c in p][::-1]:
  177. if p.lower().find("answer") >= 0:
  178. self.get_component(p)["obj"].set_exception(e)
  179. prepare2run([p])
  180. break
  181. traceback.print_exc()
  182. continue
  183. try:
  184. prepare2run(cpn["downstream"])
  185. except Exception as e:
  186. for p in [c for p in self.path for c in p][::-1]:
  187. if p.lower().find("answer") >= 0:
  188. self.get_component(p)["obj"].set_exception(e)
  189. prepare2run([p])
  190. break
  191. traceback.print_exc()
  192. if self.answer:
  193. cpn_id = self.answer[0]
  194. self.answer.pop(0)
  195. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  196. self.path[-1].append(cpn_id)
  197. if kwargs.get("stream"):
  198. assert isinstance(ans, partial)
  199. return ans
  200. self.history.append(("assistant", ans.to_dict("records")))
  201. return ans
  202. def get_component(self, cpn_id):
  203. return self.components[cpn_id]
  204. def get_tenant_id(self):
  205. return self._tenant_id
  206. def get_history(self, window_size):
  207. convs = []
  208. for role, obj in self.history[window_size * -2:]:
  209. convs.append({"role": role, "content": (obj if role == "user" else
  210. '\n'.join(pd.DataFrame(obj)['content']))})
  211. return convs
  212. def add_user_input(self, question):
  213. self.history.append(("user", question))
  214. def set_embedding_model(self, embed_id):
  215. self._embed_id = embed_id
  216. def get_embedding_model(self):
  217. return self._embed_id