Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import json
  18. import logging
  19. import time
  20. from concurrent.futures import ThreadPoolExecutor
  21. from copy import deepcopy
  22. from functools import partial
  23. from typing import Any, Union, Tuple
  24. from agent.component import component_class
  25. from agent.component.base import ComponentBase
  26. from api.db.services.file_service import FileService
  27. from api.utils import get_uuid, hash_str2int
  28. from rag.prompts.prompts import chunks_format
  29. from rag.utils.redis_conn import REDIS_CONN
  30. class Canvas:
  31. """
  32. dsl = {
  33. "components": {
  34. "begin": {
  35. "obj":{
  36. "component_name": "Begin",
  37. "params": {},
  38. },
  39. "downstream": ["answer_0"],
  40. "upstream": [],
  41. },
  42. "retrieval_0": {
  43. "obj": {
  44. "component_name": "Retrieval",
  45. "params": {}
  46. },
  47. "downstream": ["generate_0"],
  48. "upstream": ["answer_0"],
  49. },
  50. "generate_0": {
  51. "obj": {
  52. "component_name": "Generate",
  53. "params": {}
  54. },
  55. "downstream": ["answer_0"],
  56. "upstream": ["retrieval_0"],
  57. }
  58. },
  59. "history": [],
  60. "path": ["begin"],
  61. "retrieval": {"chunks": [], "doc_aggs": []},
  62. "globals": {
  63. "sys.query": "",
  64. "sys.user_id": tenant_id,
  65. "sys.conversation_turns": 0,
  66. "sys.files": []
  67. }
  68. }
  69. """
  def __init__(self, dsl: str, tenant_id=None, task_id=None):
      """Create a canvas from a JSON DSL string and load its components.

      Args:
          dsl: JSON text describing the canvas. When it is empty/falsy a
              minimal DSL holding only a "begin" component is used instead.
          tenant_id: Kept as ``_tenant_id`` and used for the initial
              ``sys.user_id`` global.
          task_id: Identifier of this canvas run; generated with
              ``get_uuid()`` when not supplied.

      Raises:
          Whatever ``json.loads`` or ``load()`` raise for an invalid DSL.
      """
      # Initial runtime state. load() (called below) replaces components,
      # path, history and globals with the values stored in the DSL.
      self.path = []
      self.history = []
      self.components = {}
      self.error = ""
      self.globals = {
          "sys.query": "",
          "sys.user_id": tenant_id,
          "sys.conversation_turns": 0,
          "sys.files": []
      }

      if dsl:
          self.dsl = json.loads(dsl)
      else:
          # Fallback canvas: a lone "begin" node with a greeting prologue.
          begin_node = {
              "obj": {
                  "component_name": "Begin",
                  "params": {
                      "prologue": "Hi there!"
                  }
              },
              "downstream": [],
              "upstream": [],
              "parent_id": ""
          }
          default_globals = {
              "sys.query": "",
              "sys.user_id": "",
              "sys.conversation_turns": 0,
              "sys.files": []
          }
          self.dsl = {
              "components": {"begin": begin_node},
              "history": [],
              "path": [],
              "retrieval": [],
              "globals": default_globals
          }

      self._tenant_id = tenant_id
      self.task_id = task_id or get_uuid()
      self.load()
  def load(self):
      """Instantiate every component described in ``self.dsl``.

      Each component entry's ``obj`` description (``component_name`` +
      ``params``) is replaced in place by a live component object built via
      ``component_class``. Runtime state (path, history, globals, retrieval,
      memory) is then taken from the DSL.

      Raises:
          AssertionError: No component of type "Begin" is declared.
          ValueError: A component's parameters fail ``check()``; the message
              is prefixed with the component's display name.
          KeyError: A required DSL section is missing.
      """
      self.components = self.dsl["components"]

      declared_types = {node["obj"]["component_name"] for node in self.components.values()}
      assert "Begin" in declared_types, "There have to be an 'Begin' component."

      for cpn_id, node in self.components.items():
          type_name = node["obj"]["component_name"]
          # The parameter class is looked up by convention: "<Type>Param".
          param = component_class(type_name + "Param")()
          param.update(node["obj"]["params"])
          try:
              param.check()
          except Exception as e:
              raise ValueError(self.get_component_name(cpn_id) + f": {e}")
          node["obj"] = component_class(type_name)(self, cpn_id, param)

      self.path = self.dsl["path"]
      self.history = self.dsl["history"]
      self.globals = self.dsl["globals"]
      self.retrieval = self.dsl["retrieval"]
      # "memory" is optional in the DSL.
      self.memory = self.dsl.get("memory", [])
  128. def __str__(self):
  129. self.dsl["path"] = self.path
  130. self.dsl["history"] = self.history
  131. self.dsl["globals"] = self.globals
  132. self.dsl["task_id"] = self.task_id
  133. self.dsl["retrieval"] = self.retrieval
  134. self.dsl["memory"] = self.memory
  135. dsl = {
  136. "components": {}
  137. }
  138. for k in self.dsl.keys():
  139. if k in ["components"]:
  140. continue
  141. dsl[k] = deepcopy(self.dsl[k])
  142. for k, cpn in self.components.items():
  143. if k not in dsl["components"]:
  144. dsl["components"][k] = {}
  145. for c in cpn.keys():
  146. if c == "obj":
  147. dsl["components"][k][c] = json.loads(str(cpn["obj"]))
  148. continue
  149. dsl["components"][k][c] = deepcopy(cpn[c])
  150. return json.dumps(dsl, ensure_ascii=False)
  151. def reset(self, mem=False):
  152. self.path = []
  153. if not mem:
  154. self.history = []
  155. self.retrieval = []
  156. self.memory = []
  157. for k, cpn in self.components.items():
  158. self.components[k]["obj"].reset()
  159. for k in self.globals.keys():
  160. if isinstance(self.globals[k], str):
  161. self.globals[k] = ""
  162. elif isinstance(self.globals[k], int):
  163. self.globals[k] = 0
  164. elif isinstance(self.globals[k], float):
  165. self.globals[k] = 0
  166. elif isinstance(self.globals[k], list):
  167. self.globals[k] = []
  168. elif isinstance(self.globals[k], dict):
  169. self.globals[k] = {}
  170. else:
  171. self.globals[k] = None
  172. try:
  173. REDIS_CONN.delete(f"{self.task_id}-logs")
  174. except Exception as e:
  175. logging.exception(e)
  176. def get_component_name(self, cid):
  177. for n in self.dsl.get("graph", {}).get("nodes", []):
  178. if cid == n["id"]:
  179. return n["data"]["name"]
  180. return ""
  def run(self, **kwargs):
      """Execute the canvas and stream progress as event dicts (generator).

      Keyword arguments read here:
          query: Appended to the history as the user turn and, when truthy,
              stored in the ``sys.query`` global.
          user_id: When truthy, stored in ``sys.user_id``.
          files: When truthy, converted with ``get_files()`` and stored in
              ``sys.files``.
          inputs: Passed to "begin"/"userfillup" components and echoed in
              the workflow_started / workflow_finished events.

      Yields:
          Dicts with the keys ``event``, ``message_id``, ``created_at``,
          ``task_id`` and ``data``. Events emitted from this method are
          "workflow_started", "node_started", "message", "message_end",
          "node_finished", "user_inputs" and "workflow_finished".

      NOTE(review): indentation was lost in the copy this was reviewed from;
      the nesting below is a reconstruction and should be compared with the
      canonical file.
      """
      st = time.perf_counter()
      self.message_id = get_uuid()
      created_at = int(time.time())
      self.add_user_input(kwargs.get("query"))

      # Copy the recognised, non-empty keyword arguments into "sys.*" globals.
      for k in kwargs.keys():
          if k in ["query", "user_id", "files"] and kwargs[k]:
              if k == "files":
                  self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
              else:
                  self.globals[f"sys.{k}"] = kwargs[k]
      # Treat a missing/None/0 turn counter as 0 before incrementing it.
      if not self.globals["sys.conversation_turns"] :
          self.globals["sys.conversation_turns"] = 0
      self.globals["sys.conversation_turns"] += 1

      def decorate(event, dt):
          # Wrap a payload in the common event envelope.
          nonlocal created_at
          return {
              "event": event,
              #"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
              "message_id": self.message_id,
              "created_at": created_at,
              "task_id": self.task_id,
              "data": dt
          }

      # Start from "begin" unless the path ends on an entry whose id contains
      # "userfillup" (i.e. a previous call stopped to ask for user input).
      if not self.path or self.path[-1].lower().find("userfillup") < 0:
          self.path.append("begin")
          self.retrieval.append({"chunks": [], "doc_aggs": []})

      yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
      # NOTE(review): together with the append above, a fresh start pushes two
      # retrieval entries (one with lists, one with dicts); add_refernce()
      # writes into the last one -- confirm the first append is intended.
      self.retrieval.append({"chunks": {}, "doc_aggs": {}})

      def _run_batch(f, t):
          # Invoke path[f:t] concurrently and wait for all of them.
          with ThreadPoolExecutor(max_workers=5) as executor:
              thr = []
              for i in range(f, t):
                  cpn = self.get_component_obj(self.path[i])
                  if cpn.component_name.lower() in ["begin", "userfillup"]:
                      thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
                  else:
                      thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
              # The loop variable rebinds the parameter ``t``; harmless here
              # because range(f, t) was already evaluated. result() re-raises
              # any exception from the worker.
              for t in thr:
                  t.result()

      def _node_finished(cpn_obj):
          # Build the "node_finished" event for one component.
          # NOTE(review): presumably "_created_time" is a perf_counter()
          # reading stored by the component -- confirm in ComponentBase.
          return decorate("node_finished",{
              "inputs": cpn_obj.get_input_values(),
              "outputs": cpn_obj.output(),
              "component_id": cpn_obj._id,
              "component_name": self.get_component_name(cpn_obj._id),
              "component_type": self.get_component_type(cpn_obj._id),
              "error": cpn_obj.error(),
              "elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
              "created_at": cpn_obj.output("_created_time"),
          })

      def _append_path(cpn_id):
          # Append unless it would duplicate the current last entry.
          if self.path[-1] == cpn_id:
              return
          self.path.append(cpn_id)

      def _extend_path(cpn_ids):
          for cpn_id in cpn_ids:
              _append_path(cpn_id)

      self.error = ""
      # Resume from the last path entry; each pass runs path[idx:to] as a batch.
      idx = len(self.path) - 1
      # Ids of components whose "content" output is still an unconsumed
      # functools.partial, so their node_finished event is deferred.
      partials = []
      while idx < len(self.path):
          to = len(self.path)
          for i in range(idx, to):
              yield decorate("node_started", {
                  "inputs": None, "created_at": int(time.time()),
                  "component_id": self.path[i],
                  "component_name": self.get_component_name(self.path[i]),
                  "component_type": self.get_component_type(self.path[i]),
                  "thoughts": self.get_component_thoughts(self.path[i])
              })
          _run_batch(idx, to)

          # post processing of components invocation
          for i in range(idx, to):
              cpn = self.get_component(self.path[i])
              if cpn["obj"].component_name.lower() == "message":
                  if isinstance(cpn["obj"].output("content"), partial):
                      # Streaming output: call the partial and forward each
                      # piece; "<think>" / "</think>" become flag-only events.
                      _m = ""
                      for m in cpn["obj"].output("content")():
                          if not m:
                              continue
                          if m == "<think>":
                              yield decorate("message", {"content": "", "start_to_think": True})
                          elif m == "</think>":
                              yield decorate("message", {"content": "", "end_to_think": True})
                          else:
                              yield decorate("message", {"content": m})
                              # NOTE(review): placed under ``else`` by
                              # reconstruction (think markers excluded from
                              # the stored text) -- confirm.
                              _m += m
                      # Replace the partial with the fully accumulated text.
                      cpn["obj"].set_output("content", _m)
                  else:
                      yield decorate("message", {"content": cpn["obj"].output("content")})
                  yield decorate("message_end", {"reference": self.get_reference()})

              # Flush deferred node_finished events, in order, for components
              # whose "content" is no longer a partial.
              while partials:
                  _cpn = self.get_component(partials[0])
                  if isinstance(_cpn["obj"].output("content"), partial):
                      break
                  yield _node_finished(_cpn["obj"])
                  partials.pop(0)

              if cpn["obj"].error():
                  ex = cpn["obj"].exception_handler()
                  if ex and ex["comment"]:
                      yield decorate("message", {"content": ex["comment"]})
                      yield decorate("message_end", {})
                  if ex and ex["goto"]:
                      # The handler redirects execution to another component.
                      self.path.append(ex["goto"])
                  elif not ex or not ex["default_value"]:
                      # No handler, or no default value: the run fails.
                      self.error = cpn["obj"].error()

              if cpn["obj"].component_name.lower() != "iteration":
                  if isinstance(cpn["obj"].output("content"), partial):
                      if self.error:
                          cpn["obj"].set_output("content", None)
                          yield _node_finished(cpn["obj"])
                      else:
                          partials.append(self.path[i])
                  else:
                      yield _node_finished(cpn["obj"])

              # Decide which component(s) run next.
              if cpn["obj"].component_name.lower() == "iterationitem" and cpn["obj"].end():
                  # ``iter`` shadows the builtin inside this method only.
                  iter = cpn["obj"].get_parent()
                  yield _node_finished(iter)
                  _extend_path(self.get_component(cpn["parent_id"])["downstream"])
              elif cpn["obj"].component_name.lower() in ["categorize", "switch"]:
                  _extend_path(cpn["obj"].output("_next"))
              elif cpn["obj"].component_name.lower() == "iteration":
                  _append_path(cpn["obj"].get_start())
              elif not cpn["downstream"] and cpn["obj"].get_parent():
                  _append_path(cpn["obj"].get_parent().get_start())
              else:
                  _extend_path(cpn["downstream"])

          if self.error:
              logging.error(f"Runtime Error: {self.error}")
              break
          idx = to

          # If any queued component is a "userfillup", move those to the front
          # of the queue, ask the caller for their inputs and stop here.
          if any([self.get_component(c)["obj"].component_name.lower() == "userfillup" for c in self.path[idx:]]):
              path = [c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() == "userfillup"]
              path.extend([c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() != "userfillup"])
              another_inputs = {}
              tips = ""
              for c in path:
                  o = self.get_component(c)["obj"]
                  if o.component_name.lower() == "userfillup":
                      another_inputs.update(o.get_input_elements())
                      if o.get_param("enable_tips"):
                          tips = o.get_param("tips")
              # Only the pending components are kept in the path.
              self.path = path
              yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
              return

      # Drop every path entry at or after idx.
      self.path = self.path[:idx]
      if not self.error:
          # NOTE(review): "created_at" is set from ``st``, a perf_counter()
          # reading rather than a wall-clock timestamp -- confirm intent.
          yield decorate("workflow_finished",
                         {
                             "inputs": kwargs.get("inputs"),
                             "outputs": self.get_component_obj(self.path[-1]).output(),
                             "elapsed_time": time.perf_counter() - st,
                             "created_at": st,
                         })
          self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
  337. def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
  338. return self.components.get(cpn_id)
  339. def get_component_obj(self, cpn_id) -> ComponentBase:
  340. return self.components.get(cpn_id)["obj"]
  341. def get_component_type(self, cpn_id) -> str:
  342. return self.components.get(cpn_id)["obj"].component_name
  343. def get_component_input_form(self, cpn_id) -> dict:
  344. return self.components.get(cpn_id)["obj"].get_input_form()
  345. def is_reff(self, exp: str) -> bool:
  346. exp = exp.strip("{").strip("}")
  347. if exp.find("@") < 0:
  348. return exp in self.globals
  349. arr = exp.split("@")
  350. if len(arr) != 2:
  351. return False
  352. if self.get_component(arr[0]) is None:
  353. return False
  354. return True
  355. def get_variable_value(self, exp: str) -> Any:
  356. exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
  357. if exp.find("@") < 0:
  358. return self.globals[exp]
  359. cpn_id, var_nm = exp.split("@")
  360. cpn = self.get_component(cpn_id)
  361. if not cpn:
  362. raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
  363. return cpn["obj"].output(var_nm)
  364. def get_tenant_id(self):
  365. return self._tenant_id
  366. def get_history(self, window_size):
  367. convs = []
  368. if window_size <= 0:
  369. return convs
  370. for role, obj in self.history[window_size * -1:]:
  371. if isinstance(obj, dict):
  372. convs.append({"role": role, "content": obj.get("content", "")})
  373. else:
  374. convs.append({"role": role, "content": str(obj)})
  375. return convs
  376. def add_user_input(self, question):
  377. self.history.append(("user", question))
  def _find_loop(self, max_loops=6):
      """Look for a pattern repeated at the end of the last path entry.

      Args:
          max_loops: Number of consecutive repetitions that is still
              tolerated; a pattern is reported once it repeats more often.

      Returns:
          ``False`` when no such repetition is found, otherwise a string of
          the repeated ids joined by " => ", written twice.

      NOTE(review): this treats ``self.path[-1]`` as a sequence of component
      ids, whereas ``run()`` appends plain id strings to ``self.path`` -- in
      that case the code below walks single characters. Confirm whether this
      method is still in use. Indentation was lost in the reviewed copy; the
      nesting is a reconstruction.
      """
      # Newest first.
      path = self.path[-1][::-1]
      if len(path) < 2:
          return False

      # Keep only what ran after the most recent entry whose name starts
      # with "answer" or "iterationitem".
      for i in range(len(path)):
          if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
              path = path[:i]
              break

      if len(path) < 2:
          return False

      # Try candidate patterns made of the ``loc`` most recent entries.
      for loc in range(2, len(path) // 2):
          pat = ",".join(path[0:loc])
          path_str = ",".join(path)
          if len(pat) >= len(path_str):
              return False
          loop = max_loops
          # Strip the pattern (plus its trailing comma) from the front while
          # it keeps matching, counting the repetitions down.
          while path_str.find(pat) == 0 and loop >= 0:
              loop -= 1
              if len(pat)+1 >= len(path_str):
                  return False
              path_str = path_str[len(pat)+1:]
          if loop < 0:
              # Report ids only: everything before the first ":" of each entry.
              pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
              return pat + " => " + pat

      return False
  403. def get_prologue(self):
  404. return self.components["begin"]["obj"]._param.prologue
  405. def set_global_param(self, **kwargs):
  406. self.globals.update(kwargs)
  407. def get_preset_param(self):
  408. return self.components["begin"]["obj"]._param.inputs
  409. def get_component_input_elements(self, cpnnm):
  410. return self.components[cpnnm]["obj"].get_input_elements()
  411. def get_files(self, files: Union[None, list[dict]]) -> list[str]:
  412. if not files:
  413. return []
  414. def image_to_base64(file):
  415. return "data:{};base64,{}".format(file["mime_type"],
  416. base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
  417. exe = ThreadPoolExecutor(max_workers=5)
  418. threads = []
  419. for file in files:
  420. if file["mime_type"].find("image") >=0:
  421. threads.append(exe.submit(image_to_base64, file))
  422. continue
  423. threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
  424. return [th.result() for th in threads]
  def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any):
      """Append one tool invocation to the trace log kept in Redis.

      The log lives under ``"<task_id>-<message_id>-logs"`` as a list of
      ``{"component_id", "trace"}`` records. The call is added to the last
      record when it belongs to the same top-level agent, otherwise a new
      record is started. The key is written back with a lifetime argument
      of ``60*10``. Any failure is logged and swallowed.

      Args:
          agent_id: Agent ids joined by "-->"; the first one is the
              top-level component id.
          func_name: Name of the tool that was called.
          params: Arguments the tool was called with.
          result: What the tool returned.
      """
      agent_ids = agent_id.split("-->")
      root_id = agent_ids[0]
      agent_name = self.get_component_name(root_id)
      # Human-readable path: display name of the root, raw ids for the rest.
      if len(agent_ids) < 2:
          path = agent_name
      else:
          path = agent_name+"-->"+"-->".join(agent_ids[1:])

      try:
          log_key = f"{self.task_id}-{self.message_id}-logs"
          call = {"path": path, "tool_name": func_name, "arguments": params, "result": result}
          stored = REDIS_CONN.get(log_key)
          if not stored:
              records = [{
                  "component_id": root_id,
                  "trace": [call]
              }]
          else:
              records = json.loads(stored.encode("utf-8"))
              if records[-1]["component_id"] == root_id:
                  records[-1]["trace"].append(call)
              else:
                  records.append({
                      "component_id": root_id,
                      "trace": [call]
                  })
          REDIS_CONN.set_obj(log_key, records, 60*10)
      except Exception as e:
          logging.exception(e)
  def add_refernce(self, chunks: list[object], doc_infos: list[object]):
      """Merge chunks and document summaries into the latest reference set.

      The latest entry of ``self.retrieval`` is used, and created first when
      the list is empty. Chunks are normalised with ``chunks_format`` and
      stored under the integer ``hash_str2int(ck["id"], 100)``; documents
      are stored under their "doc_name". Entries already present are kept
      as they are.

      Args:
          chunks: Retrieved chunks, passed to ``chunks_format``.
          doc_infos: Document summaries; each must have a "doc_name" key.
      """
      if not self.retrieval:
          self.retrieval = [{"chunks": {}, "doc_aggs": {}}]

      latest = self.retrieval[-1]
      for ck in chunks_format({"chunks": chunks}):
          cid = hash_str2int(ck["id"], 100)
          # Test against the chunk map itself. Testing the outer dict (whose
          # only keys are "chunks"/"doc_aggs") never detected an existing
          # entry, so the guard had no effect.
          if cid not in latest["chunks"]:
              latest["chunks"][cid] = ck

      for doc in doc_infos:
          # Same for documents; this also stops a document literally named
          # "chunks" or "doc_aggs" from being dropped.
          if doc["doc_name"] not in latest["doc_aggs"]:
              latest["doc_aggs"][doc["doc_name"]] = doc
  459. def get_reference(self):
  460. if not self.retrieval:
  461. return {"chunks": {}, "doc_aggs": {}}
  462. return self.retrieval[-1]
  463. def add_memory(self, user:str, assist:str, summ: str):
  464. self.memory.append((user, assist, summ))
  465. def get_memory(self) -> list[Tuple]:
  466. return self.memory
  467. def get_component_thoughts(self, cpn_id) -> str:
  468. return self.components.get(cpn_id)["obj"].thoughts()