Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import json
  18. import logging
  19. import time
  20. from concurrent.futures import ThreadPoolExecutor
  21. from copy import deepcopy
  22. from functools import partial
  23. from typing import Any, Union, Tuple
  24. from agent.component import component_class
  25. from agent.component.base import ComponentBase
  26. from api.db.services.file_service import FileService
  27. from api.utils import get_uuid, hash_str2int
  28. from rag.prompts.prompts import chunks_format
  29. from rag.utils.redis_conn import REDIS_CONN
class Canvas:
    """
    Executable agent workflow built from a JSON DSL.

    ``__init__`` parses the DSL (or falls back to a minimal DSL containing only a
    ``begin`` component), ``load`` instantiates every component, ``run`` executes
    the workflow as a generator of event dicts, and ``str(canvas)`` serializes the
    components plus the runtime state back to JSON.

    Illustrative DSL shape (ids and wiring are examples only; ``answer_0`` is
    referenced below but its definition is not shown):

    dsl = {
        "components": {
            "begin": {
                "obj":{
                    "component_name": "Begin",
                    "params": {},
                },
                "downstream": ["answer_0"],
                "upstream": [],
            },
            "retrieval_0": {
                "obj": {
                    "component_name": "Retrieval",
                    "params": {}
                },
                "downstream": ["generate_0"],
                "upstream": ["answer_0"],
            },
            "generate_0": {
                "obj": {
                    "component_name": "Generate",
                    "params": {}
                },
                "downstream": ["answer_0"],
                "upstream": ["retrieval_0"],
            }
        },
        "history": [],
        "path": ["begin"],
        "retrieval": {"chunks": [], "doc_aggs": []},
        "globals": {
            "sys.query": "",
            "sys.user_id": tenant_id,
            "sys.conversation_turns": 0,
            "sys.files": []
        }
    }

    Notes:
      * ``path`` is a flat list of component ids in execution order.
      * At runtime ``retrieval`` is a *list* of ``{"chunks", "doc_aggs"}`` entries
        (``run`` appends to it); ``get_reference`` returns the last entry.
      * A component entry may also carry ``parent_id`` (read by ``run`` for
        iteration items), and ``graph.nodes`` (when present) supplies display
        names through ``get_component_name``.
    """
  70. def __init__(self, dsl: str, tenant_id=None, task_id=None):
  71. self.path = []
  72. self.history = []
  73. self.components = {}
  74. self.error = ""
  75. self.globals = {
  76. "sys.query": "",
  77. "sys.user_id": tenant_id,
  78. "sys.conversation_turns": 0,
  79. "sys.files": []
  80. }
  81. self.dsl = json.loads(dsl) if dsl else {
  82. "components": {
  83. "begin": {
  84. "obj": {
  85. "component_name": "Begin",
  86. "params": {
  87. "prologue": "Hi there!"
  88. }
  89. },
  90. "downstream": [],
  91. "upstream": [],
  92. "parent_id": ""
  93. }
  94. },
  95. "history": [],
  96. "path": [],
  97. "retrieval": [],
  98. "globals": {
  99. "sys.query": "",
  100. "sys.user_id": "",
  101. "sys.conversation_turns": 0,
  102. "sys.files": []
  103. }
  104. }
  105. self._tenant_id = tenant_id
  106. self.task_id = task_id if task_id else get_uuid()
  107. self.load()
  108. def load(self):
  109. self.components = self.dsl["components"]
  110. cpn_nms = set([])
  111. for k, cpn in self.components.items():
  112. cpn_nms.add(cpn["obj"]["component_name"])
  113. assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
  114. for k, cpn in self.components.items():
  115. cpn_nms.add(cpn["obj"]["component_name"])
  116. param = component_class(cpn["obj"]["component_name"] + "Param")()
  117. param.update(cpn["obj"]["params"])
  118. try:
  119. param.check()
  120. except Exception as e:
  121. raise ValueError(self.get_component_name(k) + f": {e}")
  122. cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
  123. self.path = self.dsl["path"]
  124. self.history = self.dsl["history"]
  125. self.globals = self.dsl["globals"]
  126. self.retrieval = self.dsl["retrieval"]
  127. self.memory = self.dsl.get("memory", [])
  128. def __str__(self):
  129. self.dsl["path"] = self.path
  130. self.dsl["history"] = self.history
  131. self.dsl["globals"] = self.globals
  132. self.dsl["task_id"] = self.task_id
  133. self.dsl["retrieval"] = self.retrieval
  134. self.dsl["memory"] = self.memory
  135. dsl = {
  136. "components": {}
  137. }
  138. for k in self.dsl.keys():
  139. if k in ["components"]:
  140. continue
  141. dsl[k] = deepcopy(self.dsl[k])
  142. for k, cpn in self.components.items():
  143. if k not in dsl["components"]:
  144. dsl["components"][k] = {}
  145. for c in cpn.keys():
  146. if c == "obj":
  147. dsl["components"][k][c] = json.loads(str(cpn["obj"]))
  148. continue
  149. dsl["components"][k][c] = deepcopy(cpn[c])
  150. return json.dumps(dsl, ensure_ascii=False)
  151. def reset(self, mem=False):
  152. self.path = []
  153. if not mem:
  154. self.history = []
  155. self.retrieval = []
  156. self.memory = []
  157. for k, cpn in self.components.items():
  158. self.components[k]["obj"].reset()
  159. for k in self.globals.keys():
  160. if isinstance(self.globals[k], str):
  161. self.globals[k] = ""
  162. elif isinstance(self.globals[k], int):
  163. self.globals[k] = 0
  164. elif isinstance(self.globals[k], float):
  165. self.globals[k] = 0
  166. elif isinstance(self.globals[k], list):
  167. self.globals[k] = []
  168. elif isinstance(self.globals[k], dict):
  169. self.globals[k] = {}
  170. else:
  171. self.globals[k] = None
  172. try:
  173. REDIS_CONN.delete(f"{self.task_id}-logs")
  174. except Exception as e:
  175. logging.exception(e)
  176. def get_component_name(self, cid):
  177. for n in self.dsl.get("graph", {}).get("nodes", []):
  178. if cid == n["id"]:
  179. return n["data"]["name"]
  180. return ""
    def run(self, **kwargs):
        """Execute the workflow, yielding progress events.

        This is a generator. Every yielded item is built by the inner
        ``decorate`` helper: ``{"event", "message_id", "created_at", "task_id",
        "data"}``. Event names used: ``workflow_started``, ``node_started``,
        ``message``, ``message_end``, ``node_finished``, ``user_inputs`` and
        ``workflow_finished``.

        Keyword args read here: ``query`` (added to history; also copied into
        ``sys.query``), ``user_id``, ``files`` (converted by ``get_files``) and
        ``inputs`` (passed to Begin/UserFillUp components). Only truthy values of
        ``query``/``user_id``/``files`` overwrite the matching globals.

        Execution proceeds in batches: all ids appended to ``self.path`` since
        the previous batch are invoked concurrently, then post-processed in order.
        Post-processing may append further ids. The loop ends when no new ids were
        appended, when a component error is not covered by its exception handler,
        or when a UserFillUp component is reached (the run then yields
        ``user_inputs`` and returns so it can be resumed later).
        """
        st = time.perf_counter()
        self.message_id = get_uuid()
        created_at = int(time.time())
        self.add_user_input(kwargs.get("query"))

        for k in kwargs.keys():
            if k in ["query", "user_id", "files"] and kwargs[k]:
                if k == "files":
                    self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
                else:
                    self.globals[f"sys.{k}"] = kwargs[k]
        if not self.globals["sys.conversation_turns"] :
            self.globals["sys.conversation_turns"] = 0
        self.globals["sys.conversation_turns"] += 1

        def decorate(event, dt):
            # Wrap event payload `dt` in the common event envelope.
            nonlocal created_at
            return {
                "event": event,
                #"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
                "message_id": self.message_id,
                "created_at": created_at,
                "task_id": self.task_id,
                "data": dt
            }

        # Start from "begin" unless the previous run paused on a UserFillUp
        # component, in which case execution resumes from that component.
        if not self.path or self.path[-1].lower().find("userfillup") < 0:
            self.path.append("begin")
            # NOTE(review): a list-shaped entry is appended here and a
            # dict-shaped one just below, so a fresh run grows self.retrieval
            # by two; get_reference()/add_refernce() use the last (dict) one.
            self.retrieval.append({"chunks": [], "doc_aggs": []})

        yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
        self.retrieval.append({"chunks": {}, "doc_aggs": {}})

        def _run_batch(f, t):
            # Invoke components self.path[f:t] concurrently. Begin/UserFillUp get
            # the caller's `inputs`; others get their resolved get_input() kwargs.
            # result() re-raises any exception raised inside a worker.
            with ThreadPoolExecutor(max_workers=5) as executor:
                thr = []
                for i in range(f, t):
                    cpn = self.get_component_obj(self.path[i])
                    if cpn.component_name.lower() in ["begin", "userfillup"]:
                        thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
                    else:
                        thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
                # The loop variable `t` shadows the parameter; harmless because
                # range(f, t) was already evaluated above.
                for t in thr:
                    t.result()

        def _node_finished(cpn_obj):
            # Build the node_finished event for a component.
            # NOTE(review): elapsed_time assumes the "_created_time" output is a
            # time.perf_counter() value — confirm against ComponentBase.
            return decorate("node_finished",{
                           "inputs": cpn_obj.get_input_values(),
                           "outputs": cpn_obj.output(),
                           "component_id": cpn_obj._id,
                           "component_name": self.get_component_name(cpn_obj._id),
                           "component_type": self.get_component_type(cpn_obj._id),
                           "error": cpn_obj.error(),
                           "elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
                           "created_at": cpn_obj.output("_created_time"),
                       })

        self.error = ""
        # First batch is just the last path entry ("begin" or the resumed component).
        idx = len(self.path) - 1
        # Ids whose "content" output is still a lazy partial (unconsumed stream);
        # their node_finished event is deferred until that is no longer the case.
        partials = []
        while idx < len(self.path):
            to = len(self.path)
            for i in range(idx, to):
                yield decorate("node_started", {
                    "inputs": None, "created_at": int(time.time()),
                    "component_id": self.path[i],
                    "component_name": self.get_component_name(self.path[i]),
                    "component_type": self.get_component_type(self.path[i]),
                    "thoughts": self.get_component_thoughts(self.path[i])
                })
            _run_batch(idx, to)

            # post processing of components invocation
            for i in range(idx, to):
                cpn = self.get_component(self.path[i])
                cpn_obj = self.get_component_obj(self.path[i])
                if cpn_obj.component_name.lower() == "message":
                    # Stream the message. A partial "content" is called to get
                    # an iterable of chunks; <think>/</think> markers become flag
                    # events, other chunks are forwarded and concatenated into
                    # the final stored content.
                    if isinstance(cpn_obj.output("content"), partial):
                        _m = ""
                        for m in cpn_obj.output("content")():
                            if not m:
                                continue
                            if m == "<think>":
                                yield decorate("message", {"content": "", "start_to_think": True})
                            elif m == "</think>":
                                yield decorate("message", {"content": "", "end_to_think": True})
                            else:
                                yield decorate("message", {"content": m})
                                _m += m
                        cpn_obj.set_output("content", _m)
                    else:
                        yield decorate("message", {"content": cpn_obj.output("content")})
                    yield decorate("message_end", {"reference": self.get_reference()})

                    # Flush deferred node_finished events, in order, for leading
                    # entries whose content is no longer a partial.
                    while partials:
                        _cpn_obj = self.get_component_obj(partials[0])
                        if isinstance(_cpn_obj.output("content"), partial):
                            break
                        yield _node_finished(_cpn_obj)
                        partials.pop(0)

                # Error handling per the component's exception handler:
                #  - "goto": jump to the handler's target ids instead of the
                #    normal downstream routing below (other_branch suppresses it);
                #  - "default_value": emit it as a message and carry on;
                #  - otherwise record the error; the run stops after this batch.
                other_branch = False
                if cpn_obj.error():
                    ex = cpn_obj.exception_handler()
                    if ex and ex["goto"]:
                        self.path.extend(ex["goto"])
                        other_branch = True
                    elif ex and ex["default_value"]:
                        yield decorate("message", {"content": ex["default_value"]})
                        yield decorate("message_end", {})
                    else:
                        self.error = cpn_obj.error()

                # Iteration components report node_finished when their
                # IterationItem ends (see routing below), not here.
                if cpn_obj.component_name.lower() != "iteration":
                    if isinstance(cpn_obj.output("content"), partial):
                        if self.error:
                            cpn_obj.set_output("content", None)
                            yield _node_finished(cpn_obj)
                        else:
                            partials.append(self.path[i])
                    else:
                        yield _node_finished(cpn_obj)

                def _append_path(cpn_id):
                    # Append one id unless a goto already redirected this
                    # component or the id is already the last path entry.
                    nonlocal other_branch
                    if other_branch:
                        return
                    if self.path[-1] == cpn_id:
                        return
                    self.path.append(cpn_id)

                def _extend_path(cpn_ids):
                    nonlocal other_branch
                    if other_branch:
                        return
                    for cpn_id in cpn_ids:
                        _append_path(cpn_id)

                # Routing: decide which ids run in the next batch.
                if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
                    # Iteration exhausted: finish the parent Iteration and continue
                    # with the downstream of the entry named by "parent_id".
                    iter = cpn_obj.get_parent()
                    yield _node_finished(iter)
                    _extend_path(self.get_component(cpn["parent_id"])["downstream"])
                elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
                    # Branching components publish the chosen ids in "_next".
                    _extend_path(cpn_obj.output("_next"))
                elif cpn_obj.component_name.lower() == "iteration":
                    _append_path(cpn_obj.get_start())
                elif not cpn["downstream"] and cpn_obj.get_parent():
                    # Last component inside a parent: go back to the parent's
                    # start component (presumably the next iteration round).
                    _append_path(cpn_obj.get_parent().get_start())
                else:
                    _extend_path(cpn["downstream"])

            if self.error:
                logging.error(f"Runtime Error: {self.error}")
                break
            idx = to

            # Pause when a UserFillUp is pending: move UserFillUp ids to the front
            # of the remaining path, collect their input elements (and tips when
            # enabled), and yield user_inputs so the caller can supply values.
            if any([self.get_component_obj(c).component_name.lower() == "userfillup" for c in self.path[idx:]]):
                path = [c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() == "userfillup"]
                path.extend([c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() != "userfillup"])
                another_inputs = {}
                tips = ""
                for c in path:
                    o = self.get_component_obj(c)
                    if o.component_name.lower() == "userfillup":
                        another_inputs.update(o.get_input_elements())
                        if o.get_param("enable_tips"):
                            tips = o.get_param("tips")
                self.path = path
                yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
                return

        # No-op when the loop finished normally (idx == len(self.path)); after a
        # break on error, idx was not advanced, so this drops the failing batch
        # and anything it appended.
        self.path = self.path[:idx]
        if not self.error:
            # NOTE(review): "created_at" here is the perf_counter() start value,
            # not a wall-clock timestamp like the envelope's created_at — confirm.
            yield decorate("workflow_finished",
                {
                    "inputs": kwargs.get("inputs"),
                    "outputs": self.get_component_obj(self.path[-1]).output(),
                    "elapsed_time": time.perf_counter() - st,
                    "created_at": st,
                })
            # Record the last component's outputs as the assistant turn.
            self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
  346. def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
  347. return self.components.get(cpn_id)
  348. def get_component_obj(self, cpn_id) -> ComponentBase:
  349. return self.components.get(cpn_id)["obj"]
  350. def get_component_type(self, cpn_id) -> str:
  351. return self.components.get(cpn_id)["obj"].component_name
  352. def get_component_input_form(self, cpn_id) -> dict:
  353. return self.components.get(cpn_id)["obj"].get_input_form()
  354. def is_reff(self, exp: str) -> bool:
  355. exp = exp.strip("{").strip("}")
  356. if exp.find("@") < 0:
  357. return exp in self.globals
  358. arr = exp.split("@")
  359. if len(arr) != 2:
  360. return False
  361. if self.get_component(arr[0]) is None:
  362. return False
  363. return True
  364. def get_variable_value(self, exp: str) -> Any:
  365. exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
  366. if exp.find("@") < 0:
  367. return self.globals[exp]
  368. cpn_id, var_nm = exp.split("@")
  369. cpn = self.get_component(cpn_id)
  370. if not cpn:
  371. raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
  372. return cpn["obj"].output(var_nm)
  373. def get_tenant_id(self):
  374. return self._tenant_id
  375. def get_history(self, window_size):
  376. convs = []
  377. if window_size <= 0:
  378. return convs
  379. for role, obj in self.history[window_size * -1:]:
  380. if isinstance(obj, dict):
  381. convs.append({"role": role, "content": obj.get("content", "")})
  382. else:
  383. convs.append({"role": role, "content": str(obj)})
  384. return convs
  385. def add_user_input(self, question):
  386. self.history.append(("user", question))
  387. def _find_loop(self, max_loops=6):
  388. path = self.path[-1][::-1]
  389. if len(path) < 2:
  390. return False
  391. for i in range(len(path)):
  392. if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
  393. path = path[:i]
  394. break
  395. if len(path) < 2:
  396. return False
  397. for loc in range(2, len(path) // 2):
  398. pat = ",".join(path[0:loc])
  399. path_str = ",".join(path)
  400. if len(pat) >= len(path_str):
  401. return False
  402. loop = max_loops
  403. while path_str.find(pat) == 0 and loop >= 0:
  404. loop -= 1
  405. if len(pat)+1 >= len(path_str):
  406. return False
  407. path_str = path_str[len(pat)+1:]
  408. if loop < 0:
  409. pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
  410. return pat + " => " + pat
  411. return False
  412. def get_prologue(self):
  413. return self.components["begin"]["obj"]._param.prologue
  414. def set_global_param(self, **kwargs):
  415. self.globals.update(kwargs)
  416. def get_preset_param(self):
  417. return self.components["begin"]["obj"]._param.inputs
  418. def get_component_input_elements(self, cpnnm):
  419. return self.components[cpnnm]["obj"].get_input_elements()
  420. def get_files(self, files: Union[None, list[dict]]) -> list[str]:
  421. if not files:
  422. return []
  423. def image_to_base64(file):
  424. return "data:{};base64,{}".format(file["mime_type"],
  425. base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
  426. exe = ThreadPoolExecutor(max_workers=5)
  427. threads = []
  428. for file in files:
  429. if file["mime_type"].find("image") >=0:
  430. threads.append(exe.submit(image_to_base64, file))
  431. continue
  432. threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
  433. return [th.result() for th in threads]
  434. def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
  435. agent_ids = agent_id.split("-->")
  436. agent_name = self.get_component_name(agent_ids[0])
  437. path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
  438. try:
  439. bin = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs")
  440. if bin:
  441. obj = json.loads(bin.encode("utf-8"))
  442. if obj[-1]["component_id"] == agent_ids[0]:
  443. obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time})
  444. else:
  445. obj.append({
  446. "component_id": agent_ids[0],
  447. "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
  448. })
  449. else:
  450. obj = [{
  451. "component_id": agent_ids[0],
  452. "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
  453. }]
  454. REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
  455. except Exception as e:
  456. logging.exception(e)
  457. def add_refernce(self, chunks: list[object], doc_infos: list[object]):
  458. if not self.retrieval:
  459. self.retrieval = [{"chunks": {}, "doc_aggs": {}}]
  460. r = self.retrieval[-1]
  461. for ck in chunks_format({"chunks": chunks}):
  462. cid = hash_str2int(ck["id"], 100)
  463. if cid not in r:
  464. r["chunks"][cid] = ck
  465. for doc in doc_infos:
  466. if doc["doc_name"] not in r:
  467. r["doc_aggs"][doc["doc_name"]] = doc
  468. def get_reference(self):
  469. if not self.retrieval:
  470. return {"chunks": {}, "doc_aggs": {}}
  471. return self.retrieval[-1]
  472. def add_memory(self, user:str, assist:str, summ: str):
  473. self.memory.append((user, assist, summ))
  474. def get_memory(self) -> list[Tuple]:
  475. return self.memory
  476. def get_component_thoughts(self, cpn_id) -> str:
  477. return self.components.get(cpn_id)["obj"].thoughts()