Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

canvas.py 8.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import importlib
  17. import json
  18. import traceback
  19. from abc import ABC
  20. from copy import deepcopy
  21. from functools import partial
  22. import pandas as pd
  23. from graph.component import component_class
  24. from graph.component.base import ComponentBase
  25. from graph.settings import flow_logger, DEBUG
  26. class Canvas(ABC):
  27. """
  28. dsl = {
  29. "components": {
  30. "begin": {
  31. "obj":{
  32. "component_name": "Begin",
  33. "params": {},
  34. },
  35. "downstream": ["answer_0"],
  36. "upstream": [],
  37. },
  38. "answer_0": {
  39. "obj": {
  40. "component_name": "Answer",
  41. "params": {}
  42. },
  43. "downstream": ["retrieval_0"],
  44. "upstream": ["begin", "generate_0"],
  45. },
  46. "retrieval_0": {
  47. "obj": {
  48. "component_name": "Retrieval",
  49. "params": {}
  50. },
  51. "downstream": ["generate_0"],
  52. "upstream": ["answer_0"],
  53. },
  54. "generate_0": {
  55. "obj": {
  56. "component_name": "Generate",
  57. "params": {}
  58. },
  59. "downstream": ["answer_0"],
  60. "upstream": ["retrieval_0"],
  61. }
  62. },
  63. "history": [],
  64. "messages": [],
  65. "reference": [],
  66. "path": [["begin"]],
  67. "answer": []
  68. }
  69. """
  70. def __init__(self, dsl: str, tenant_id=None):
  71. self.path = []
  72. self.history = []
  73. self.messages = []
  74. self.answer = []
  75. self.components = {}
  76. self.dsl = json.loads(dsl) if dsl else {
  77. "components": {
  78. "begin": {
  79. "obj": {
  80. "component_name": "Begin",
  81. "params": {
  82. "prologue": "Hi there!"
  83. }
  84. },
  85. "downstream": [],
  86. "upstream": []
  87. }
  88. },
  89. "history": [],
  90. "messages": [],
  91. "reference": [],
  92. "path": [],
  93. "answer": []
  94. }
  95. self._tenant_id = tenant_id
  96. self._embed_id = ""
  97. self.load()
  98. def load(self):
  99. self.components = self.dsl["components"]
  100. cpn_nms = set([])
  101. for k, cpn in self.components.items():
  102. cpn_nms.add(cpn["obj"]["component_name"])
  103. assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
  104. assert "Answer" in cpn_nms, "There have to be an 'Answer' component."
  105. for k, cpn in self.components.items():
  106. cpn_nms.add(cpn["obj"]["component_name"])
  107. param = component_class(cpn["obj"]["component_name"] + "Param")()
  108. param.update(cpn["obj"]["params"])
  109. param.check()
  110. cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
  111. if cpn["obj"].component_name == "Categorize":
  112. for _, desc in param.category_description.items():
  113. if desc["to"] not in cpn["downstream"]:
  114. cpn["downstream"].append(desc["to"])
  115. self.path = self.dsl["path"]
  116. self.history = self.dsl["history"]
  117. self.messages = self.dsl["messages"]
  118. self.answer = self.dsl["answer"]
  119. self.reference = self.dsl["reference"]
  120. self._embed_id = self.dsl.get("embed_id", "")
  121. def __str__(self):
  122. self.dsl["path"] = self.path
  123. self.dsl["history"] = self.history
  124. self.dsl["messages"] = self.messages
  125. self.dsl["answer"] = self.answer
  126. self.dsl["reference"] = self.reference
  127. self.dsl["embed_id"] = self._embed_id
  128. dsl = deepcopy(self.dsl)
  129. for k, cpn in self.components.items():
  130. dsl["components"][k]["obj"] = json.loads(str(cpn["obj"]))
  131. return json.dumps(dsl, ensure_ascii=False)
  132. def reset(self):
  133. self.path = []
  134. self.history = []
  135. self.messages = []
  136. self.answer = []
  137. self.reference = []
  138. for k, cpn in self.components.items():
  139. self.components[k]["obj"].reset()
  140. self._embed_id = ""
  141. def run(self, **kwargs):
  142. ans = ""
  143. if self.answer:
  144. cpn_id = self.answer[0]
  145. self.answer.pop(0)
  146. try:
  147. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  148. except Exception as e:
  149. ans = ComponentBase.be_output(str(e))
  150. self.path[-1].append(cpn_id)
  151. self.history.append(("assistant", ans.to_dict("records")))
  152. return ans
  153. if not self.path:
  154. self.components["begin"]["obj"].run(self.history, **kwargs)
  155. self.path.append(["begin"])
  156. self.path.append([])
  157. ran = -1
  158. def prepare2run(cpns):
  159. nonlocal ran, ans
  160. for c in cpns:
  161. cpn = self.components[c]["obj"]
  162. if cpn.component_name == "Answer":
  163. self.answer.append(c)
  164. else:
  165. if DEBUG: print("RUN: ", c)
  166. ans = cpn.run(self.history, **kwargs)
  167. self.path[-1].append(c)
  168. ran += 1
  169. prepare2run(self.components[self.path[-2][-1]]["downstream"])
  170. while 0 <= ran < len(self.path[-1]):
  171. if DEBUG: print(ran, self.path)
  172. cpn_id = self.path[-1][ran]
  173. cpn = self.get_component(cpn_id)
  174. if not cpn["downstream"]: break
  175. if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
  176. switch_out = cpn["obj"].output()[1].iloc[0, 0]
  177. assert switch_out in self.components, \
  178. "{}'s output: {} not valid.".format(cpn_id, switch_out)
  179. try:
  180. prepare2run([switch_out])
  181. except Exception as e:
  182. for p in [c for p in self.path for c in p][::-1]:
  183. if p.lower().find("answer") >= 0:
  184. self.get_component(p)["obj"].set_exception(e)
  185. prepare2run([p])
  186. break
  187. traceback.print_exc()
  188. continue
  189. try:
  190. prepare2run(cpn["downstream"])
  191. except Exception as e:
  192. for p in [c for p in self.path for c in p][::-1]:
  193. if p.lower().find("answer") >= 0:
  194. self.get_component(p)["obj"].set_exception(e)
  195. prepare2run([p])
  196. break
  197. traceback.print_exc()
  198. if self.answer:
  199. cpn_id = self.answer[0]
  200. self.answer.pop(0)
  201. ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
  202. self.path[-1].append(cpn_id)
  203. if kwargs.get("stream"):
  204. assert isinstance(ans, partial)
  205. return ans
  206. self.history.append(("assistant", ans.to_dict("records")))
  207. return ans
  208. def get_component(self, cpn_id):
  209. return self.components[cpn_id]
  210. def get_tenant_id(self):
  211. return self._tenant_id
  212. def get_history(self, window_size):
  213. convs = []
  214. for role, obj in self.history[window_size * -2:]:
  215. convs.append({"role": role, "content": (obj if role == "user" else
  216. '\n'.join(pd.DataFrame(obj)['content']))})
  217. return convs
  218. def add_user_input(self, question):
  219. self.history.append(("user", question))
  220. def set_embedding_model(self, embed_id):
  221. self._embed_id = embed_id
  222. def get_embedding_model(self):
  223. return self._embed_id