#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import base64
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any, Union, Tuple

from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.utils import get_uuid, hash_str2int
from rag.prompts.prompts import chunks_format
from rag.utils.redis_conn import REDIS_CONN


class Canvas:
    """
    Runtime for a component graph described by a JSON DSL.

    dsl = {
        "components": {
            "begin": {
                "obj":{
                    "component_name": "Begin",
                    "params": {},
                },
                "downstream": ["answer_0"],
                "upstream": [],
            },
            "retrieval_0": {
                "obj": {
                    "component_name": "Retrieval",
                    "params": {}
                },
                "downstream": ["generate_0"],
                "upstream": ["answer_0"],
            },
            "generate_0": {
                "obj": {
                    "component_name": "Generate",
                    "params": {}
                },
                "downstream": ["answer_0"],
                "upstream": ["retrieval_0"],
            }
        },
        "history": [],
        "path": ["begin"],
        "retrieval": {"chunks": [], "doc_aggs": []},
        "globals": {
            "sys.query": "",
            "sys.user_id": tenant_id,
            "sys.conversation_turns": 0,
            "sys.files": []
        }
    }
    """

    def __init__(self, dsl: str, tenant_id=None, task_id=None):
        """Parse the DSL (or fall back to a minimal "begin"-only graph) and load it.

        Args:
            dsl: JSON text of the canvas; a falsy value selects the built-in
                default graph.
            tenant_id: Stored as the tenant id and as the initial "sys.user_id".
            task_id: Identifier of this run; a fresh uuid is generated if falsy.
        """
        self.path = []
        self.history = []
        self.components = {}
        self.error = ""
        self.globals = {
            "sys.query": "",
            "sys.user_id": tenant_id,
            "sys.conversation_turns": 0,
            "sys.files": []
        }
        self.dsl = json.loads(dsl) if dsl else {
            "components": {
                "begin": {
                    "obj": {
                        "component_name": "Begin",
                        "params": {
                            "prologue": "Hi there!"
                        }
                    },
                    "downstream": [],
                    "upstream": [],
                    "parent_id": ""
                }
            },
            "history": [],
            "path": [],
            "retrieval": [],
            "globals": {
                "sys.query": "",
                "sys.user_id": "",
                "sys.conversation_turns": 0,
                "sys.files": []
            }
        }
        self._tenant_id = tenant_id
        self.task_id = task_id if task_id else get_uuid()
        self.load()

    def load(self):
        """Instantiate every component of ``self.dsl`` and restore run state.

        Each ``components[k]["obj"]`` dict is replaced by the component
        instance built from its ``component_name`` and validated params.

        Raises:
            AssertionError: if no component named "Begin" exists.
            ValueError: if a component's parameter check fails.
        """
        self.components = self.dsl["components"]
        cpn_nms = {cpn["obj"]["component_name"] for cpn in self.components.values()}
        # Explicit raise instead of `assert`, so the check also runs under -O.
        if "Begin" not in cpn_nms:
            raise AssertionError("There have to be an 'Begin' component.")

        for k, cpn in self.components.items():
            component_name = cpn["obj"]["component_name"]
            param = component_class(component_name + "Param")()
            param.update(cpn["obj"]["params"])
            try:
                param.check()
            except Exception as e:
                raise ValueError(self.get_component_name(k) + f": {e}") from e

            cpn["obj"] = component_class(component_name)(self, k, param)

        self.path = self.dsl["path"]
        self.history = self.dsl["history"]
        self.globals = self.dsl["globals"]
        self.retrieval = self.dsl["retrieval"]
        self.memory = self.dsl.get("memory", [])

    def __str__(self):
        """Serialize the canvas (state plus components) back to DSL JSON text.

        Note: this writes the current run state into ``self.dsl`` as a side
        effect before building the copy that is dumped.
        """
        self.dsl["path"] = self.path
        self.dsl["history"] = self.history
        self.dsl["globals"] = self.globals
        self.dsl["task_id"] = self.task_id
        self.dsl["retrieval"] = self.retrieval
        self.dsl["memory"] = self.memory

        dsl = {"components": {}}
        for k, v in self.dsl.items():
            if k == "components":
                continue
            dsl[k] = deepcopy(v)

        for k, cpn in self.components.items():
            entry = dsl["components"].setdefault(k, {})
            for c, value in cpn.items():
                if c == "obj":
                    # Component objects serialize themselves via str().
                    entry[c] = json.loads(str(value))
                else:
                    entry[c] = deepcopy(value)
        return json.dumps(dsl, ensure_ascii=False)

    def reset(self, mem=False):
        """Clear the execution path, component state and global values.

        Args:
            mem: when true, history, retrieval and memory are kept.
        """
        self.path = []
        if not mem:
            self.history = []
            self.retrieval = []
            self.memory = []
        for cpn in self.components.values():
            cpn["obj"].reset()

        # Reset each global to the "empty" value of its current type.
        for k, v in self.globals.items():
            if isinstance(v, str):
                self.globals[k] = ""
            elif isinstance(v, (int, float)):
                self.globals[k] = 0
            elif isinstance(v, list):
                self.globals[k] = []
            elif isinstance(v, dict):
                self.globals[k] = {}
            else:
                self.globals[k] = None

        try:
            # NOTE(review): tool_use_callback writes to
            # "{task_id}-{message_id}-logs"; this key has no message id -
            # confirm which key is meant to be cleared.
            REDIS_CONN.delete(f"{self.task_id}-logs")
        except Exception as e:
            logging.exception(e)

    def get_component_name(self, cid):
        """Return the display name of node ``cid`` from ``dsl["graph"]["nodes"]``, or ""."""
        for n in self.dsl.get("graph", {}).get("nodes", []):
            if cid == n["id"]:
                return n["data"]["name"]
        return ""

    def run(self, **kwargs):
        """Execute the graph and yield event dicts as it progresses.

        Recognized keyword arguments: ``query``, ``user_id``, ``files``
        (copied into the ``sys.*`` globals when truthy) and ``inputs``
        (passed to "begin"/"userfillup" components).

        Yields:
            dict: events with keys "event", "message_id", "created_at",
            "task_id" and "data". Event names emitted here are
            "workflow_started", "node_started", "message", "message_end",
            "node_finished", "user_inputs" and "workflow_finished".
        """
        st = time.perf_counter()
        self.message_id = get_uuid()
        created_at = int(time.time())
        self.add_user_input(kwargs.get("query"))

        for k, v in kwargs.items():
            if k in ("query", "user_id", "files") and v:
                if k == "files":
                    self.globals[f"sys.{k}"] = self.get_files(v)
                else:
                    self.globals[f"sys.{k}"] = v
        if not self.globals["sys.conversation_turns"]:
            self.globals["sys.conversation_turns"] = 0
        self.globals["sys.conversation_turns"] += 1

        def decorate(event, dt):
            # Wrap a payload in the common event envelope.
            return {
                "event": event,
                "message_id": self.message_id,
                "created_at": created_at,
                "task_id": self.task_id,
                "data": dt
            }

        # Start from "begin" unless we are resuming at a pending user fill-up.
        if not self.path or self.path[-1].lower().find("userfillup") < 0:
            self.path.append("begin")
            self.retrieval.append({"chunks": [], "doc_aggs": []})

        yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
        # This dict-shaped entry is the one add_refernce/get_reference use.
        self.retrieval.append({"chunks": {}, "doc_aggs": {}})

        def _run_batch(start, stop):
            # Invoke path[start:stop] concurrently and wait for all of them;
            # .result() re-raises any exception from a worker.
            with ThreadPoolExecutor(max_workers=5) as executor:
                futures = []
                for i in range(start, stop):
                    cpn = self.get_component_obj(self.path[i])
                    if cpn.component_name.lower() in ["begin", "userfillup"]:
                        futures.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
                    else:
                        futures.append(executor.submit(cpn.invoke, **cpn.get_input()))
                for fut in futures:
                    fut.result()

        def _node_finished(cpn_obj):
            return decorate("node_finished", {
                "inputs": cpn_obj.get_input_values(),
                "outputs": cpn_obj.output(),
                "component_id": cpn_obj._id,
                "component_name": self.get_component_name(cpn_obj._id),
                "component_type": self.get_component_type(cpn_obj._id),
                "error": cpn_obj.error(),
                "elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
                "created_at": cpn_obj.output("_created_time"),
            })

        # Set per processed component; when an exception handler redirected
        # the flow ("goto"), the regular downstream must not be scheduled.
        other_branch = False

        def _append_path(cpn_id):
            if other_branch:
                return
            if self.path[-1] == cpn_id:
                return
            self.path.append(cpn_id)

        def _extend_path(cpn_ids):
            if other_branch:
                return
            for cpn_id in cpn_ids:
                _append_path(cpn_id)

        self.error = ""
        idx = len(self.path) - 1
        # Ids of components whose "content" output is still an unconsumed
        # partial; their node_finished event is deferred.
        partials = []
        while idx < len(self.path):
            to = len(self.path)
            for i in range(idx, to):
                yield decorate("node_started", {
                    "inputs": None, "created_at": int(time.time()),
                    "component_id": self.path[i],
                    "component_name": self.get_component_name(self.path[i]),
                    "component_type": self.get_component_type(self.path[i]),
                    "thoughts": self.get_component_thoughts(self.path[i])
                })
            _run_batch(idx, to)

            # post processing of components invocation
            for i in range(idx, to):
                cpn = self.get_component(self.path[i])
                cpn_obj = self.get_component_obj(self.path[i])
                if cpn_obj.component_name.lower() == "message":
                    if isinstance(cpn_obj.output("content"), partial):
                        _m = ""
                        for m in cpn_obj.output("content")():
                            if not m:
                                continue
                            # NOTE(review): the reviewed text compared `m` with
                            # empty strings here, which can never match after
                            # the falsy check above. The reasoning delimiters
                            # are restored as "<think>" / "</think>" - confirm
                            # against the tokens the Message component streams.
                            if m == "<think>":
                                yield decorate("message", {"content": "", "start_to_think": True})
                            elif m == "</think>":
                                yield decorate("message", {"content": "", "end_to_think": True})
                            else:
                                yield decorate("message", {"content": m})
                                _m += m
                        # Replace the lazy partial by the text actually streamed.
                        cpn_obj.set_output("content", _m)
                    else:
                        yield decorate("message", {"content": cpn_obj.output("content")})
                    yield decorate("message_end", {"reference": self.get_reference()})

                # Flush deferred node_finished events, in order, as soon as
                # their content is no longer a partial.
                while partials:
                    _cpn_obj = self.get_component_obj(partials[0])
                    if isinstance(_cpn_obj.output("content"), partial):
                        break
                    yield _node_finished(_cpn_obj)
                    partials.pop(0)

                other_branch = False
                if cpn_obj.error():
                    ex = cpn_obj.exception_handler()
                    if ex and ex["goto"]:
                        self.path.extend(ex["goto"])
                        other_branch = True
                    elif ex and ex["default_value"]:
                        yield decorate("message", {"content": ex["default_value"]})
                        yield decorate("message_end", {})
                    else:
                        self.error = cpn_obj.error()

                if cpn_obj.component_name.lower() != "iteration":
                    if isinstance(cpn_obj.output("content"), partial):
                        if self.error:
                            cpn_obj.set_output("content", None)
                            yield _node_finished(cpn_obj)
                        else:
                            partials.append(self.path[i])
                    else:
                        yield _node_finished(cpn_obj)

                # Schedule what runs next.
                if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
                    parent_obj = cpn_obj.get_parent()
                    yield _node_finished(parent_obj)
                    _extend_path(self.get_component(cpn["parent_id"])["downstream"])
                elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
                    _extend_path(cpn_obj.output("_next"))
                elif cpn_obj.component_name.lower() == "iteration":
                    _append_path(cpn_obj.get_start())
                elif not cpn["downstream"] and cpn_obj.get_parent():
                    _append_path(cpn_obj.get_parent().get_start())
                else:
                    _extend_path(cpn["downstream"])

            if self.error:
                logging.error(f"Runtime Error: {self.error}")
                break
            idx = to

            pending = self.path[idx:]
            if any(self.get_component_obj(c).component_name.lower() == "userfillup" for c in pending):
                # Pause: keep only the pending components (user fill-ups
                # first) as the path and ask the caller for their inputs.
                path = [c for c in pending if self.get_component_obj(c).component_name.lower() == "userfillup"]
                path.extend(c for c in pending if self.get_component_obj(c).component_name.lower() != "userfillup")
                another_inputs = {}
                tips = ""
                for c in path:
                    o = self.get_component_obj(c)
                    if o.component_name.lower() == "userfillup":
                        another_inputs.update(o.get_input_elements())
                        if o.get_param("enable_tips"):
                            tips = o.get_param("tips")
                self.path = path
                yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
                return

        self.path = self.path[:idx]
        if not self.error:
            yield decorate("workflow_finished", {
                "inputs": kwargs.get("inputs"),
                "outputs": self.get_component_obj(self.path[-1]).output(),
                "elapsed_time": time.perf_counter() - st,
                # NOTE(review): `st` is a perf_counter reading, not a
                # wall-clock timestamp - confirm what consumers expect.
                "created_at": st,
            })
            self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))

    def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
        """Return the DSL entry of component ``cpn_id``, or None if unknown."""
        return self.components.get(cpn_id)

    def get_component_obj(self, cpn_id) -> ComponentBase:
        """Return the component instance of ``cpn_id`` (fails for unknown ids)."""
        return self.components.get(cpn_id)["obj"]

    def get_component_type(self, cpn_id) -> str:
        """Return the ``component_name`` of component ``cpn_id``."""
        return self.components.get(cpn_id)["obj"].component_name

    def get_component_input_form(self, cpn_id) -> dict:
        """Return the input form of component ``cpn_id``."""
        return self.components.get(cpn_id)["obj"].get_input_form()

    def is_reff(self, exp: str) -> bool:
        """Tell whether ``exp`` names a global or a ``component_id@variable``.

        Surrounding braces are stripped first. Only the component id is
        checked for the ``@`` form, not the variable name.
        """
        exp = exp.strip("{").strip("}")
        if exp.find("@") < 0:
            return exp in self.globals
        arr = exp.split("@")
        if len(arr) != 2:
            return False
        return self.get_component(arr[0]) is not None

    def get_variable_value(self, exp: str) -> Any:
        """Resolve ``exp`` to a global value or to a component output.

        Raises:
            KeyError: for an unknown global name.
            Exception: if the referenced component does not exist.
        """
        exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
        if exp.find("@") < 0:
            return self.globals[exp]
        cpn_id, var_nm = exp.split("@")
        cpn = self.get_component(cpn_id)
        if not cpn:
            raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
        return cpn["obj"].output(var_nm)

    def get_tenant_id(self):
        """Return the tenant id given at construction."""
        return self._tenant_id

    def get_history(self, window_size):
        """Return the last ``window_size`` history entries as role/content dicts.

        Dict entries contribute their "content" value (default ""); anything
        else is converted with ``str``. A non-positive window yields [].
        """
        convs = []
        if window_size <= 0:
            return convs
        for role, obj in self.history[window_size * -1:]:
            if isinstance(obj, dict):
                convs.append({"role": role, "content": obj.get("content", "")})
            else:
                convs.append({"role": role, "content": str(obj)})
        return convs

    def add_user_input(self, question):
        """Append ``question`` to the history under the "user" role."""
        self.history.append(("user", question))

    def _find_loop(self, max_loops=6):
        """Look for a pattern repeated more than ``max_loops`` times.

        Returns a " => "-joined description of the pattern, or False.

        NOTE(review): this reverses ``self.path[-1]`` and treats the result as
        a sequence of component ids, which presumes path entries are
        themselves sequences; ``run`` appends plain id strings - confirm this
        helper is still used and with what path shape.
        """
        path = self.path[-1][::-1]
        if len(path) < 2:
            return False

        for i in range(len(path)):
            if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
                path = path[:i]
                break

        if len(path) < 2:
            return False

        for loc in range(2, len(path) // 2):
            pat = ",".join(path[0:loc])
            path_str = ",".join(path)
            if len(pat) >= len(path_str):
                return False
            loop = max_loops
            while path_str.find(pat) == 0 and loop >= 0:
                loop -= 1
                if len(pat)+1 >= len(path_str):
                    return False
                path_str = path_str[len(pat)+1:]
            if loop < 0:
                pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
                return pat + " => " + pat

        return False

    def get_prologue(self):
        """Return the ``prologue`` parameter of the "begin" component."""
        return self.components["begin"]["obj"]._param.prologue

    def set_global_param(self, **kwargs):
        """Merge ``kwargs`` into the canvas globals."""
        self.globals.update(kwargs)

    def get_preset_param(self):
        """Return the ``inputs`` parameter of the "begin" component."""
        return self.components["begin"]["obj"]._param.inputs

    def get_component_input_elements(self, cpnnm):
        """Return the input elements of component ``cpnnm``."""
        return self.components[cpnnm]["obj"].get_input_elements()

    def get_files(self, files: Union[None, list[dict]]) -> list[str]:
        """Convert uploaded file descriptors into strings, in input order.

        Files whose "mime_type" contains "image" become base64 data URLs;
        all others are passed to ``FileService.parse``. Each descriptor must
        provide "mime_type", "created_by", "id" and (non-images) "name".
        """
        if not files:
            return []

        def image_to_base64(file):
            return "data:{};base64,{}".format(
                file["mime_type"],
                base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))

        # The context manager shuts the pool down once all results are in,
        # so worker threads are not leaked.
        with ThreadPoolExecutor(max_workers=5) as exe:
            threads = []
            for file in files:
                if file["mime_type"].find("image") >= 0:
                    threads.append(exe.submit(image_to_base64, file))
                    continue
                threads.append(exe.submit(
                    FileService.parse, file["name"],
                    FileService.get_blob(file["created_by"], file["id"]),
                    True, file["created_by"]))
            return [th.result() for th in threads]

    def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any):
        """Append one tool-call trace to this message's log stored in Redis.

        ``agent_id`` may be a "-->"-separated chain; its first element is the
        component id the trace is grouped under. Consecutive calls of the
        same component share one log entry. Failures are logged, not raised.
        """
        agent_ids = agent_id.split("-->")
        agent_name = self.get_component_name(agent_ids[0])
        path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
        trace = {"path": path, "tool_name": func_name, "arguments": params, "result": result}
        try:
            log_key = f"{self.task_id}-{self.message_id}-logs"
            raw = REDIS_CONN.get(log_key)
            if raw:
                # json.loads accepts both str and bytes, whichever the
                # Redis client returns.
                obj = json.loads(raw)
                if obj[-1]["component_id"] == agent_ids[0]:
                    obj[-1]["trace"].append(trace)
                else:
                    obj.append({"component_id": agent_ids[0], "trace": [trace]})
            else:
                obj = [{"component_id": agent_ids[0], "trace": [trace]}]
            REDIS_CONN.set_obj(log_key, obj, 60*10)
        except Exception as e:
            logging.exception(e)

    def add_refernce(self, chunks: list[object], doc_infos: list[object]):
        """Record retrieved chunks and documents in the latest retrieval entry.

        Chunks are keyed by ``hash_str2int(chunk["id"], 100)`` and documents
        by "doc_name"; a key that is already present is left untouched.
        """
        if not self.retrieval:
            self.retrieval = [{"chunks": {}, "doc_aggs": {}}]

        r = self.retrieval[-1]
        for ck in chunks_format({"chunks": chunks}):
            cid = hash_str2int(ck["id"], 100)
            # Check the chunk map itself; testing `r` (whose only keys are
            # "chunks"/"doc_aggs") made this guard always true.
            if cid not in r["chunks"]:
                r["chunks"][cid] = ck

        for doc in doc_infos:
            if doc["doc_name"] not in r["doc_aggs"]:
                r["doc_aggs"][doc["doc_name"]] = doc

    # Correctly spelled alias; the original name is kept for existing callers.
    add_reference = add_refernce

    def get_reference(self):
        """Return the latest retrieval entry, or an empty one if none exists."""
        if not self.retrieval:
            return {"chunks": {}, "doc_aggs": {}}
        return self.retrieval[-1]

    def add_memory(self, user: str, assist: str, summ: str):
        """Append a ``(user, assist, summ)`` tuple to the memory list."""
        self.memory.append((user, assist, summ))

    def get_memory(self) -> list[Tuple]:
        """Return the memory list."""
        return self.memory

    def get_component_thoughts(self, cpn_id) -> str:
        """Return ``thoughts()`` of component ``cpn_id``."""
        return self.components.get(cpn_id)["obj"].thoughts()