| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538 |
- #
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import base64
- import json
- import logging
- import time
- from concurrent.futures import ThreadPoolExecutor
- from copy import deepcopy
- from functools import partial
- from typing import Any, Union, Tuple
-
- from agent.component import component_class
- from agent.component.base import ComponentBase
- from api.db.services.file_service import FileService
- from api.utils import get_uuid, hash_str2int
- from rag.prompts.prompts import chunks_format
- from rag.utils.redis_conn import REDIS_CONN
-
-
class Canvas:
    """Executable agent workflow built from a JSON DSL.

    The DSL's ``components`` map each component id to its serialized spec;
    :meth:`load` replaces every ``obj`` spec with a live component instance
    created via ``component_class``. :meth:`run` walks the graph and yields
    streaming event dicts, and ``str(canvas)`` serializes the definition plus
    runtime state (path, history, globals, retrieval, memory, task_id) back
    to JSON.

    Illustrative DSL shape (``tenant_id`` is a placeholder, not a literal)::

        {
            "components": {
                "begin": {
                    "obj": {"component_name": "Begin", "params": {}},
                    "downstream": ["retrieval_0"],
                    "upstream": [],
                },
                "retrieval_0": {
                    "obj": {"component_name": "Retrieval", "params": {}},
                    "downstream": ["generate_0"],
                    "upstream": ["begin"],
                },
                "generate_0": {
                    "obj": {"component_name": "Generate", "params": {}},
                    "downstream": [],
                    "upstream": ["retrieval_0"],
                },
            },
            "history": [],
            "path": ["begin"],
            "retrieval": [],   # list of {"chunks": ..., "doc_aggs": ...} dicts
            "globals": {
                "sys.query": "",
                "sys.user_id": tenant_id,
                "sys.conversation_turns": 0,
                "sys.files": [],
            },
        }

    An optional ``graph`` key (``{"nodes": [{"id": ..., "data": {"name": ...}}]}``)
    supplies display names for :meth:`get_component_name`.
    """

    def __init__(self, dsl: str, tenant_id=None, task_id=None):
        """Parse ``dsl`` (or fall back to a Begin-only default) and load it.

        Args:
            dsl: JSON string of the canvas DSL. A falsy value selects a
                built-in default containing a single ``Begin`` component.
            tenant_id: Owner tenant, returned by :meth:`get_tenant_id`.
            task_id: Id stamped on emitted events and used in Redis keys;
                a fresh UUID is generated when omitted.
        """
        self.path = []
        self.history = []
        self.components = {}
        self.error = ""
        # NOTE: load() replaces this with the DSL's own "globals", so for the
        # built-in default DSL below "sys.user_id" ends up as "".
        self.globals = {
            "sys.query": "",
            "sys.user_id": tenant_id,
            "sys.conversation_turns": 0,
            "sys.files": []
        }
        self.dsl = json.loads(dsl) if dsl else {
            "components": {
                "begin": {
                    "obj": {
                        "component_name": "Begin",
                        "params": {
                            "prologue": "Hi there!"
                        }
                    },
                    "downstream": [],
                    "upstream": [],
                    "parent_id": ""
                }
            },
            "history": [],
            "path": [],
            "retrieval": [],
            "globals": {
                "sys.query": "",
                "sys.user_id": "",
                "sys.conversation_turns": 0,
                "sys.files": []
            }
        }
        self._tenant_id = tenant_id
        self.task_id = task_id if task_id else get_uuid()
        self.load()

    def load(self):
        """Instantiate component objects from ``self.dsl`` and restore runtime state.

        Raises:
            AssertionError: if no component is a ``Begin``.
            ValueError: if a component's parameters fail ``check()``; the
                message is prefixed with the component's display name.
        """
        self.components = self.dsl["components"]
        cpn_nms = {cpn["obj"]["component_name"] for cpn in self.components.values()}
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # the exception type is kept for existing callers.
        if "Begin" not in cpn_nms:
            raise AssertionError("There have to be an 'Begin' component.")

        for k, cpn in self.components.items():
            cpn_name = cpn["obj"]["component_name"]
            param = component_class(cpn_name + "Param")()
            param.update(cpn["obj"]["params"])
            try:
                param.check()
            except Exception as e:
                raise ValueError(self.get_component_name(k) + f": {e}") from e
            # Replace the serialized spec with the live component instance.
            cpn["obj"] = component_class(cpn_name)(self, k, param)

        self.path = self.dsl["path"]
        self.history = self.dsl["history"]
        self.globals = self.dsl["globals"]
        self.retrieval = self.dsl["retrieval"]
        self.memory = self.dsl.get("memory", [])

    def __str__(self):
        """Serialize the canvas, including runtime state, to a DSL JSON string.

        Side effect: copies the current path/history/globals/task_id/
        retrieval/memory into ``self.dsl`` before serializing.
        """
        self.dsl["path"] = self.path
        self.dsl["history"] = self.history
        self.dsl["globals"] = self.globals
        self.dsl["task_id"] = self.task_id
        self.dsl["retrieval"] = self.retrieval
        self.dsl["memory"] = self.memory

        # "components" stays the first key so the JSON key order is unchanged.
        dsl = {"components": {}}
        dsl.update({k: deepcopy(v) for k, v in self.dsl.items() if k != "components"})

        for k, cpn in self.components.items():
            # Live component objects serialize themselves via str() -> JSON.
            dsl["components"][k] = {
                c: (json.loads(str(v)) if c == "obj" else deepcopy(v))
                for c, v in cpn.items()
            }
        return json.dumps(dsl, ensure_ascii=False)

    def reset(self, mem=False):
        """Clear the run state so the canvas can start over.

        Args:
            mem: when True, keep ``history``, ``retrieval`` and ``memory``.

        Every component object is reset. Each global is reset to the empty
        value of its type (bool -> False, str -> "", int -> 0,
        float -> 0.0, list -> [], dict -> {}, anything else -> None).
        Deleting the Redis log key is best-effort: failures are only logged.
        """
        self.path = []
        if not mem:
            self.history = []
            self.retrieval = []
            self.memory = []
        for cpn in self.components.values():
            cpn["obj"].reset()

        for k in list(self.globals):
            v = self.globals[k]
            # bool is checked before int because bool subclasses int.
            if isinstance(v, bool):
                self.globals[k] = False
            elif isinstance(v, str):
                self.globals[k] = ""
            elif isinstance(v, int):
                self.globals[k] = 0
            elif isinstance(v, float):
                self.globals[k] = 0.0
            elif isinstance(v, list):
                self.globals[k] = []
            elif isinstance(v, dict):
                self.globals[k] = {}
            else:
                self.globals[k] = None

        # NOTE(review): tool_use_callback writes f"{task_id}-{message_id}-logs";
        # this deletes f"{task_id}-logs" -- confirm which key is intended.
        try:
            REDIS_CONN.delete(f"{self.task_id}-logs")
        except Exception as e:
            logging.exception(e)

    def get_component_name(self, cid):
        """Return the display name of component ``cid`` from ``dsl["graph"]``, or ""."""
        for n in self.dsl.get("graph", {}).get("nodes", []):
            if cid == n["id"]:
                return n["data"]["name"]
        return ""

    def run(self, **kwargs):
        """Execute the workflow, yielding streaming event dicts.

        Every yielded value has the keys ``event``, ``message_id``,
        ``created_at``, ``task_id`` and ``data``. Events emitted:
        ``workflow_started``, ``node_started``, ``message``, ``message_end``,
        ``node_finished``, ``user_inputs`` (execution pauses for a UserFillUp
        component) and ``workflow_finished``.

        Keyword Args:
            query: user question; appended to history and, when truthy,
                stored in ``sys.query``.
            user_id: stored in ``sys.user_id`` when truthy.
            files: file descriptors converted by :meth:`get_files` into
                ``sys.files`` when truthy.
            inputs: dict handed to ``Begin``/``UserFillUp`` components.

        Execution is batched. All component ids appended to ``self.path``
        since the previous batch are invoked concurrently. They are then
        post-processed in path order, and that step decides which ids run
        next.
        """
        st = time.perf_counter()
        self.message_id = get_uuid()
        created_at = int(time.time())
        self.add_user_input(kwargs.get("query"))

        for k, v in kwargs.items():
            if k not in ("query", "user_id", "files") or not v:
                continue
            self.globals[f"sys.{k}"] = self.get_files(v) if k == "files" else v
        # Normalize a missing/None/"" turn counter to 0 before incrementing.
        if not self.globals["sys.conversation_turns"]:
            self.globals["sys.conversation_turns"] = 0
        self.globals["sys.conversation_turns"] += 1

        def decorate(event, dt):
            """Wrap event payload ``dt`` in the common event envelope."""
            return {
                "event": event,
                "message_id": self.message_id,
                "created_at": created_at,
                "task_id": self.task_id,
                "data": dt
            }

        # Start a fresh traversal from "begin" unless the path ends on a
        # UserFillUp component (i.e. we are resuming after user_inputs).
        if not self.path or self.path[-1].lower().find("userfillup") < 0:
            self.path.append("begin")
            # NOTE(review): this list-shaped entry is immediately followed by
            # the dict-shaped one below, which is the one add_refernce /
            # get_reference use.
            self.retrieval.append({"chunks": [], "doc_aggs": []})

        yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
        self.retrieval.append({"chunks": {}, "doc_aggs": {}})

        def _run_batch(f, t):
            """Invoke components self.path[f:t] concurrently and wait for all."""
            with ThreadPoolExecutor(max_workers=5) as executor:
                futures = []
                for i in range(f, t):
                    cpn = self.get_component_obj(self.path[i])
                    if cpn.component_name.lower() in ["begin", "userfillup"]:
                        futures.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
                    else:
                        futures.append(executor.submit(cpn.invoke, **cpn.get_input()))
                for fut in futures:
                    # Re-raises any exception raised inside the component.
                    fut.result()

        def _node_finished(cpn_obj):
            """Build the node_finished event for ``cpn_obj``."""
            return decorate("node_finished", {
                "inputs": cpn_obj.get_input_values(),
                "outputs": cpn_obj.output(),
                "component_id": cpn_obj._id,
                "component_name": self.get_component_name(cpn_obj._id),
                "component_type": self.get_component_type(cpn_obj._id),
                "error": cpn_obj.error(),
                "elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
                "created_at": cpn_obj.output("_created_time"),
            })

        # Set per component below; True means the exception handler already
        # redirected the path ("goto"), so normal successors are not appended.
        other_branch = False

        def _append_path(cpn_id):
            """Append ``cpn_id`` unless redirected or it is already the last id."""
            if other_branch:
                return
            if self.path[-1] == cpn_id:
                return
            self.path.append(cpn_id)

        def _extend_path(cpn_ids):
            """Append each id in ``cpn_ids`` via _append_path."""
            if other_branch:
                return
            for cpn_id in cpn_ids:
                _append_path(cpn_id)

        self.error = ""
        idx = len(self.path) - 1
        # Ids of streaming components whose node_finished event is deferred
        # until their partial (lazy) output has been consumed.
        partials = []
        while idx < len(self.path):
            to = len(self.path)
            for i in range(idx, to):
                yield decorate("node_started", {
                    "inputs": None, "created_at": int(time.time()),
                    "component_id": self.path[i],
                    "component_name": self.get_component_name(self.path[i]),
                    "component_type": self.get_component_type(self.path[i]),
                    "thoughts": self.get_component_thoughts(self.path[i])
                })
            _run_batch(idx, to)

            # post processing of components invocation
            for i in range(idx, to):
                cpn = self.get_component(self.path[i])
                cpn_obj = self.get_component_obj(self.path[i])
                if cpn_obj.component_name.lower() == "message":
                    if isinstance(cpn_obj.output("content"), partial):
                        # Stream the pieces; <think> markers become flag-only
                        # events and are excluded from the stored content.
                        pieces = []
                        for m in cpn_obj.output("content")():
                            if not m:
                                continue
                            if m == "<think>":
                                yield decorate("message", {"content": "", "start_to_think": True})
                            elif m == "</think>":
                                yield decorate("message", {"content": "", "end_to_think": True})
                            else:
                                yield decorate("message", {"content": m})
                                pieces.append(m)
                        cpn_obj.set_output("content", "".join(pieces))
                    else:
                        yield decorate("message", {"content": cpn_obj.output("content")})
                    yield decorate("message_end", {"reference": self.get_reference()})

                # Flush deferred node_finished events, in order, for components
                # whose output is no longer a pending partial.
                while partials:
                    _cpn_obj = self.get_component_obj(partials[0])
                    if isinstance(_cpn_obj.output("content"), partial):
                        break
                    yield _node_finished(_cpn_obj)
                    partials.pop(0)

                other_branch = False
                if cpn_obj.error():
                    ex = cpn_obj.exception_handler()
                    if ex and ex["goto"]:
                        self.path.extend(ex["goto"])
                        other_branch = True
                    elif ex and ex["default_value"]:
                        yield decorate("message", {"content": ex["default_value"]})
                        yield decorate("message_end", {})
                    else:
                        self.error = cpn_obj.error()

                if cpn_obj.component_name.lower() != "iteration":
                    if isinstance(cpn_obj.output("content"), partial):
                        if self.error:
                            cpn_obj.set_output("content", None)
                            yield _node_finished(cpn_obj)
                        else:
                            partials.append(self.path[i])
                    else:
                        yield _node_finished(cpn_obj)

                # Decide which component ids run in the next batch.
                if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
                    parent = cpn_obj.get_parent()
                    yield _node_finished(parent)
                    _extend_path(self.get_component(cpn["parent_id"])["downstream"])
                elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
                    _extend_path(cpn_obj.output("_next"))
                elif cpn_obj.component_name.lower() == "iteration":
                    _append_path(cpn_obj.get_start())
                elif not cpn["downstream"] and cpn_obj.get_parent():
                    _append_path(cpn_obj.get_parent().get_start())
                else:
                    _extend_path(cpn["downstream"])

            if self.error:
                logging.error(f"Runtime Error: {self.error}")
                break
            idx = to

            # Pause for user input if a UserFillUp component was scheduled.
            # The saved path keeps only the pending ids, UserFillUp ones first.
            pending = self.path[idx:]
            if any(self.get_component_obj(c).component_name.lower() == "userfillup" for c in pending):
                fillups = [c for c in pending if self.get_component_obj(c).component_name.lower() == "userfillup"]
                others = [c for c in pending if self.get_component_obj(c).component_name.lower() != "userfillup"]
                another_inputs = {}
                tips = ""
                for c in fillups:
                    o = self.get_component_obj(c)
                    another_inputs.update(o.get_input_elements())
                    if o.get_param("enable_tips"):
                        tips = o.get_param("tips")
                self.path = fillups + others
                yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
                return

        self.path = self.path[:idx]
        if not self.error:
            # NOTE(review): `st` is a perf_counter() reading, not an epoch
            # timestamp, although the key is "created_at".
            yield decorate("workflow_finished",
                           {
                               "inputs": kwargs.get("inputs"),
                               "outputs": self.get_component_obj(self.path[-1]).output(),
                               "elapsed_time": time.perf_counter() - st,
                               "created_at": st,
                           })
            self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))

    def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
        """Return the component entry dict for ``cpn_id`` or None."""
        return self.components.get(cpn_id)

    def get_component_obj(self, cpn_id) -> "ComponentBase":
        """Return the live component object for ``cpn_id`` (TypeError if unknown)."""
        return self.components.get(cpn_id)["obj"]

    def get_component_type(self, cpn_id) -> str:
        """Return the component class name (e.g. "Begin") of ``cpn_id``."""
        return self.components.get(cpn_id)["obj"].component_name

    def get_component_input_form(self, cpn_id) -> dict:
        """Return ``get_input_form()`` of component ``cpn_id``."""
        return self.components.get(cpn_id)["obj"].get_input_form()

    @staticmethod
    def _strip_exp(exp: str) -> str:
        """Strip surrounding braces and spaces from a ``{...}`` variable expression."""
        return exp.strip("{").strip("}").strip(" ").strip("{").strip("}")

    def is_reff(self, exp: str) -> bool:
        """Return True if ``exp`` names a global or an existing ``cpn_id@var`` output.

        Braces/spaces are stripped exactly as in :meth:`get_variable_value`,
        so every expression that method resolves is recognized here.
        """
        exp = self._strip_exp(exp)
        if exp.find("@") < 0:
            return exp in self.globals
        arr = exp.split("@")
        if len(arr) != 2:
            return False
        if self.get_component(arr[0]) is None:
            return False
        return True

    def get_variable_value(self, exp: str) -> Any:
        """Resolve a global name or ``cpn_id@var`` expression to its value.

        Raises:
            KeyError: unknown global name.
            Exception: the referenced component does not exist.
        """
        exp = self._strip_exp(exp)
        if exp.find("@") < 0:
            return self.globals[exp]
        cpn_id, var_nm = exp.split("@")
        cpn = self.get_component(cpn_id)
        if not cpn:
            raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
        return cpn["obj"].output(var_nm)

    def get_tenant_id(self):
        """Return the tenant id passed to the constructor."""
        return self._tenant_id

    def get_history(self, window_size):
        """Return the last ``window_size`` history turns as role/content dicts.

        Dict entries contribute their "content" key (default ""); anything
        else is converted with ``str``. Non-positive sizes return [].
        """
        if window_size <= 0:
            return []
        return [
            {"role": role, "content": obj.get("content", "") if isinstance(obj, dict) else str(obj)}
            for role, obj in self.history[-window_size:]
        ]

    def add_user_input(self, question):
        """Append a ("user", question) turn to the history."""
        self.history.append(("user", question))

    def _find_loop(self, max_loops=6):
        """Detect a repeating component cycle at the end of the execution path.

        The path is scanned newest-first and cut at the most recent id that
        starts with "answer" or "iterationitem". A prefix pattern of length
        ``loc`` that repeats more than ``max_loops`` times is reported as
        "A => B => A => B" (ids truncated at ":"); otherwise returns False.
        """
        # self.path is a flat list of component ids; reverse the list itself.
        path = self.path[::-1]
        if len(path) < 2:
            return False

        for i in range(len(path)):
            if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
                path = path[:i]
                break

        if len(path) < 2:
            return False

        for loc in range(2, len(path) // 2):
            pat = ",".join(path[0:loc])
            path_str = ",".join(path)
            if len(pat) >= len(path_str):
                return False
            loop = max_loops
            while path_str.find(pat) == 0 and loop >= 0:
                loop -= 1
                if len(pat) + 1 >= len(path_str):
                    return False
                path_str = path_str[len(pat) + 1:]
            if loop < 0:
                pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
                return pat + " => " + pat

        return False

    def get_prologue(self):
        """Return the Begin component's ``prologue`` parameter."""
        return self.components["begin"]["obj"]._param.prologue

    def set_global_param(self, **kwargs):
        """Merge ``kwargs`` into the global variables."""
        self.globals.update(kwargs)

    def get_preset_param(self):
        """Return the Begin component's ``inputs`` parameter."""
        return self.components["begin"]["obj"]._param.inputs

    def get_component_input_elements(self, cpnnm):
        """Return ``get_input_elements()`` of component ``cpnnm``."""
        return self.components[cpnnm]["obj"].get_input_elements()

    def get_files(self, files: Union[None, list[dict]]) -> list[str]:
        """Convert uploaded file descriptors into model-ready content.

        Images become ``data:<mime>;base64,...`` URLs; other files go through
        ``FileService.parse``. Blob download and conversion both run on a
        small thread pool that is always shut down; results keep input order.
        Each descriptor is expected to have "mime_type", "created_by", "id"
        and (for non-images) "name".
        """
        if not files:
            return []

        def image_to_base64(file):
            blob = FileService.get_blob(file["created_by"], file["id"])
            return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(blob).decode("utf-8"))

        def parse_file(file):
            blob = FileService.get_blob(file["created_by"], file["id"])
            return FileService.parse(file["name"], blob, True, file["created_by"])

        with ThreadPoolExecutor(max_workers=5) as exe:
            futures = [
                exe.submit(image_to_base64 if "image" in file["mime_type"] else parse_file, file)
                for file in files
            ]
            return [fut.result() for fut in futures]

    def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any):
        """Record a tool call in the Redis trace log for the current message.

        ``agent_id`` may be a "-->"-joined chain; its first element is the
        component id. Consecutive calls from the same component share one
        trace entry. The value 60*10 is passed to ``set_obj`` (presumably an
        expiry in seconds -- confirm). Errors are logged, never raised.
        """
        agent_ids = agent_id.split("-->")
        agent_name = self.get_component_name(agent_ids[0])
        path = agent_name if len(agent_ids) < 2 else agent_name + "-->" + "-->".join(agent_ids[1:])
        trace_item = {"path": path, "tool_name": func_name, "arguments": params, "result": result}
        try:
            raw = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs")
            if raw:
                obj = json.loads(raw.encode("utf-8"))
                if obj[-1]["component_id"] == agent_ids[0]:
                    obj[-1]["trace"].append(trace_item)
                else:
                    obj.append({"component_id": agent_ids[0], "trace": [trace_item]})
            else:
                obj = [{"component_id": agent_ids[0], "trace": [trace_item]}]
            REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
        except Exception as e:
            logging.exception(e)

    def add_refernce(self, chunks: list[object], doc_infos: list[object]):
        """Add retrieved chunks and document aggregates to the latest reference set.

        (The method name's spelling is kept for existing callers.)
        A chunk is keyed by ``hash_str2int(id, 100)`` and a document by its
        "doc_name". The first entry stored under a key is kept.
        """
        if not self.retrieval:
            self.retrieval = [{"chunks": {}, "doc_aggs": {}}]

        r = self.retrieval[-1]
        for ck in chunks_format({"chunks": chunks}):
            cid = hash_str2int(ck["id"], 100)
            # Check the inner map, not the outer {"chunks", "doc_aggs"} dict.
            if cid not in r["chunks"]:
                r["chunks"][cid] = ck

        for doc in doc_infos:
            if doc["doc_name"] not in r["doc_aggs"]:
                r["doc_aggs"][doc["doc_name"]] = doc

    def get_reference(self):
        """Return the latest reference set, or an empty one if none exists."""
        if not self.retrieval:
            return {"chunks": {}, "doc_aggs": {}}
        return self.retrieval[-1]

    def add_memory(self, user: str, assist: str, summ: str):
        """Append a (user, assistant, summary) tuple to memory."""
        self.memory.append((user, assist, summ))

    def get_memory(self) -> list[Tuple]:
        """Return the list of stored (user, assistant, summary) tuples."""
        return self.memory

    def get_component_thoughts(self, cpn_id) -> str:
        """Return ``thoughts()`` of component ``cpn_id``."""
        return self.components.get(cpn_id)["obj"].thoughts()
|