| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538 | 
							- #
 - #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 - #
 - #  Licensed under the Apache License, Version 2.0 (the "License");
 - #  you may not use this file except in compliance with the License.
 - #  You may obtain a copy of the License at
 - #
 - #      http://www.apache.org/licenses/LICENSE-2.0
 - #
 - #  Unless required by applicable law or agreed to in writing, software
 - #  distributed under the License is distributed on an "AS IS" BASIS,
 - #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 - #  See the License for the specific language governing permissions and
 - #  limitations under the License.
 - #
 - import base64
 - import json
 - import logging
 - import time
 - from concurrent.futures import ThreadPoolExecutor
 - from copy import deepcopy
 - from functools import partial
 - from typing import Any, Union, Tuple
 - 
 - from agent.component import component_class
 - from agent.component.base import ComponentBase
 - from api.db.services.file_service import FileService
 - from api.utils import get_uuid, hash_str2int
 - from rag.prompts.prompts import chunks_format
 - from rag.utils.redis_conn import REDIS_CONN
 - 
 - 
 - class Canvas:
 -     """
 -     dsl = {
 -         "components": {
 -             "begin": {
 -                 "obj":{
 -                     "component_name": "Begin",
 -                     "params": {},
 -                 },
 -                 "downstream": ["answer_0"],
 -                 "upstream": [],
 -             },
 -             "retrieval_0": {
 -                 "obj": {
 -                     "component_name": "Retrieval",
 -                     "params": {}
 -                 },
 -                 "downstream": ["generate_0"],
 -                 "upstream": ["answer_0"],
 -             },
 -             "generate_0": {
 -                 "obj": {
 -                     "component_name": "Generate",
 -                     "params": {}
 -                 },
 -                 "downstream": ["answer_0"],
 -                 "upstream": ["retrieval_0"],
 -             }
 -         },
 -         "history": [],
 -         "path": ["begin"],
 -         "retrieval": {"chunks": [], "doc_aggs": []},
 -         "globals": {
 -             "sys.query": "",
 -             "sys.user_id": tenant_id,
 -             "sys.conversation_turns": 0,
 -             "sys.files": []
 -         }
 -     }
 -     """
 - 
 -     def __init__(self, dsl: str, tenant_id=None, task_id=None):
 -         self.path = []
 -         self.history = []
 -         self.components = {}
 -         self.error = ""
 -         self.globals = {
 -             "sys.query": "",
 -             "sys.user_id": tenant_id,
 -             "sys.conversation_turns": 0,
 -             "sys.files": []
 -         }
 -         self.dsl = json.loads(dsl) if dsl else {
 -             "components": {
 -                 "begin": {
 -                     "obj": {
 -                         "component_name": "Begin",
 -                         "params": {
 -                             "prologue": "Hi there!"
 -                         }
 -                     },
 -                     "downstream": [],
 -                     "upstream": [],
 -                     "parent_id": ""
 -                 }
 -             },
 -             "history": [],
 -             "path": [],
 -             "retrieval": [],
 -             "globals": {
 -                 "sys.query": "",
 -                 "sys.user_id": "",
 -                 "sys.conversation_turns": 0,
 -                 "sys.files": []
 -             }
 -         }
 -         self._tenant_id = tenant_id
 -         self.task_id = task_id if task_id else get_uuid()
 -         self.load()
 - 
 -     def load(self):
 -         self.components = self.dsl["components"]
 -         cpn_nms = set([])
 -         for k, cpn in self.components.items():
 -             cpn_nms.add(cpn["obj"]["component_name"])
 - 
 -         assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
 - 
 -         for k, cpn in self.components.items():
 -             cpn_nms.add(cpn["obj"]["component_name"])
 -             param = component_class(cpn["obj"]["component_name"] + "Param")()
 -             param.update(cpn["obj"]["params"])
 -             try:
 -                 param.check()
 -             except Exception as e:
 -                 raise ValueError(self.get_component_name(k) + f": {e}")
 - 
 -             cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
 - 
 -         self.path = self.dsl["path"]
 -         self.history = self.dsl["history"]
 -         self.globals = self.dsl["globals"]
 -         self.retrieval = self.dsl["retrieval"]
 -         self.memory = self.dsl.get("memory", [])
 - 
 -     def __str__(self):
 -         self.dsl["path"] = self.path
 -         self.dsl["history"] = self.history
 -         self.dsl["globals"] = self.globals
 -         self.dsl["task_id"] = self.task_id
 -         self.dsl["retrieval"] = self.retrieval
 -         self.dsl["memory"] = self.memory
 -         dsl = {
 -             "components": {}
 -         }
 -         for k in self.dsl.keys():
 -             if k in ["components"]:
 -                 continue
 -             dsl[k] = deepcopy(self.dsl[k])
 - 
 -         for k, cpn in self.components.items():
 -             if k not in dsl["components"]:
 -                 dsl["components"][k] = {}
 -             for c in cpn.keys():
 -                 if c == "obj":
 -                     dsl["components"][k][c] = json.loads(str(cpn["obj"]))
 -                     continue
 -                 dsl["components"][k][c] = deepcopy(cpn[c])
 -         return json.dumps(dsl, ensure_ascii=False)
 - 
 -     def reset(self, mem=False):
 -         self.path = []
 -         if not mem:
 -             self.history = []
 -             self.retrieval = []
 -             self.memory = []
 -         for k, cpn in self.components.items():
 -             self.components[k]["obj"].reset()
 - 
 -         for k in self.globals.keys():
 -             if isinstance(self.globals[k], str):
 -                 self.globals[k] = ""
 -             elif isinstance(self.globals[k], int):
 -                 self.globals[k] = 0
 -             elif isinstance(self.globals[k], float):
 -                 self.globals[k] = 0
 -             elif isinstance(self.globals[k], list):
 -                 self.globals[k] = []
 -             elif isinstance(self.globals[k], dict):
 -                 self.globals[k] = {}
 -             else:
 -                 self.globals[k] = None
 - 
 -         try:
 -             REDIS_CONN.delete(f"{self.task_id}-logs")
 -         except Exception as e:
 -             logging.exception(e)
 - 
 -     def get_component_name(self, cid):
 -         for n in self.dsl.get("graph", {}).get("nodes", []):
 -             if cid == n["id"]:
 -                 return n["data"]["name"]
 -         return ""
 - 
 -     def run(self, **kwargs):
 -         st = time.perf_counter()
 -         self.message_id = get_uuid()
 -         created_at = int(time.time())
 -         self.add_user_input(kwargs.get("query"))
 - 
 -         for k in kwargs.keys():
 -             if k in ["query", "user_id", "files"] and kwargs[k]:
 -                 if k == "files":
 -                     self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
 -                 else:
 -                     self.globals[f"sys.{k}"] = kwargs[k]
 -         if not self.globals["sys.conversation_turns"] :
 -             self.globals["sys.conversation_turns"] = 0
 -         self.globals["sys.conversation_turns"] += 1
 - 
 -         def decorate(event, dt):
 -             nonlocal created_at
 -             return {
 -                 "event": event,
 -                 #"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
 -                 "message_id": self.message_id,
 -                 "created_at": created_at,
 -                 "task_id": self.task_id,
 -                 "data": dt
 -             }
 - 
 -         if not self.path or self.path[-1].lower().find("userfillup") < 0:
 -             self.path.append("begin")
 -             self.retrieval.append({"chunks": [], "doc_aggs": []})
 - 
 -         yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
 -         self.retrieval.append({"chunks": {}, "doc_aggs": {}})
 - 
 -         def _run_batch(f, t):
 -             with ThreadPoolExecutor(max_workers=5) as executor:
 -                 thr = []
 -                 for i in range(f, t):
 -                     cpn = self.get_component_obj(self.path[i])
 -                     if cpn.component_name.lower() in ["begin", "userfillup"]:
 -                         thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
 -                     else:
 -                         thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
 -                 for t in thr:
 -                     t.result()
 - 
 -         def _node_finished(cpn_obj):
 -             return decorate("node_finished",{
 -                            "inputs": cpn_obj.get_input_values(),
 -                            "outputs": cpn_obj.output(),
 -                            "component_id": cpn_obj._id,
 -                            "component_name": self.get_component_name(cpn_obj._id),
 -                            "component_type": self.get_component_type(cpn_obj._id),
 -                            "error": cpn_obj.error(),
 -                            "elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
 -                            "created_at": cpn_obj.output("_created_time"),
 -                        })
 - 
 -         self.error = ""
 -         idx = len(self.path) - 1
 -         partials = []
 -         while idx < len(self.path):
 -             to = len(self.path)
 -             for i in range(idx, to):
 -                 yield decorate("node_started", {
 -                     "inputs": None, "created_at": int(time.time()),
 -                     "component_id": self.path[i],
 -                     "component_name": self.get_component_name(self.path[i]),
 -                     "component_type": self.get_component_type(self.path[i]),
 -                     "thoughts": self.get_component_thoughts(self.path[i])
 -                 })
 -             _run_batch(idx, to)
 - 
 -             # post processing of components invocation
 -             for i in range(idx, to):
 -                 cpn = self.get_component(self.path[i])
 -                 cpn_obj = self.get_component_obj(self.path[i])
 -                 if cpn_obj.component_name.lower() == "message":
 -                     if isinstance(cpn_obj.output("content"), partial):
 -                         _m = ""
 -                         for m in cpn_obj.output("content")():
 -                             if not m:
 -                                 continue
 -                             if m == "<think>":
 -                                 yield decorate("message", {"content": "", "start_to_think": True})
 -                             elif m == "</think>":
 -                                 yield decorate("message", {"content": "", "end_to_think": True})
 -                             else:
 -                                 yield decorate("message", {"content": m})
 -                                 _m += m
 -                         cpn_obj.set_output("content", _m)
 -                     else:
 -                         yield decorate("message", {"content": cpn_obj.output("content")})
 -                     yield decorate("message_end", {"reference": self.get_reference()})
 - 
 -                     while partials:
 -                         _cpn_obj = self.get_component_obj(partials[0])
 -                         if isinstance(_cpn_obj.output("content"), partial):
 -                             break
 -                         yield _node_finished(_cpn_obj)
 -                         partials.pop(0)
 - 
 -                 other_branch = False
 -                 if cpn_obj.error():
 -                     ex = cpn_obj.exception_handler()
 -                     if ex and ex["goto"]:
 -                         self.path.extend(ex["goto"])
 -                         other_branch = True
 -                     elif ex and ex["default_value"]:
 -                         yield decorate("message", {"content": ex["default_value"]})
 -                         yield decorate("message_end", {})
 -                     else:
 -                         self.error = cpn_obj.error()
 - 
 -                 if cpn_obj.component_name.lower() != "iteration":
 -                     if isinstance(cpn_obj.output("content"), partial):
 -                         if self.error:
 -                             cpn_obj.set_output("content", None)
 -                             yield _node_finished(cpn_obj)
 -                         else:
 -                             partials.append(self.path[i])
 -                     else:
 -                         yield _node_finished(cpn_obj)
 - 
 -                 def _append_path(cpn_id):
 -                     nonlocal other_branch
 -                     if other_branch:
 -                         return
 -                     if self.path[-1] == cpn_id:
 -                         return
 -                     self.path.append(cpn_id)
 - 
 -                 def _extend_path(cpn_ids):
 -                     nonlocal other_branch
 -                     if other_branch:
 -                         return
 -                     for cpn_id in cpn_ids:
 -                         _append_path(cpn_id)
 - 
 -                 if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end():
 -                     iter = cpn_obj.get_parent()
 -                     yield _node_finished(iter)
 -                     _extend_path(self.get_component(cpn["parent_id"])["downstream"])
 -                 elif cpn_obj.component_name.lower() in ["categorize", "switch"]:
 -                     _extend_path(cpn_obj.output("_next"))
 -                 elif cpn_obj.component_name.lower() == "iteration":
 -                     _append_path(cpn_obj.get_start())
 -                 elif not cpn["downstream"] and cpn_obj.get_parent():
 -                     _append_path(cpn_obj.get_parent().get_start())
 -                 else:
 -                     _extend_path(cpn["downstream"])
 - 
 -             if self.error:
 -                 logging.error(f"Runtime Error: {self.error}")
 -                 break
 -             idx = to
 - 
 -             if any([self.get_component_obj(c).component_name.lower() == "userfillup" for c in self.path[idx:]]):
 -                 path = [c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() == "userfillup"]
 -                 path.extend([c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() != "userfillup"])
 -                 another_inputs = {}
 -                 tips = ""
 -                 for c in path:
 -                     o = self.get_component_obj(c)
 -                     if o.component_name.lower() == "userfillup":
 -                         another_inputs.update(o.get_input_elements())
 -                         if o.get_param("enable_tips"):
 -                             tips = o.get_param("tips")
 -                 self.path = path
 -                 yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
 -                 return
 - 
 -         self.path = self.path[:idx]
 -         if not self.error:
 -             yield decorate("workflow_finished",
 -                        {
 -                            "inputs": kwargs.get("inputs"),
 -                            "outputs": self.get_component_obj(self.path[-1]).output(),
 -                            "elapsed_time": time.perf_counter() - st,
 -                            "created_at": st,
 -                        })
 -             self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
 - 
 -     def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
 -         return self.components.get(cpn_id)
 - 
 -     def get_component_obj(self, cpn_id) -> ComponentBase:
 -         return self.components.get(cpn_id)["obj"]
 - 
 -     def get_component_type(self, cpn_id) -> str:
 -         return self.components.get(cpn_id)["obj"].component_name
 - 
 -     def get_component_input_form(self, cpn_id) -> dict:
 -         return self.components.get(cpn_id)["obj"].get_input_form()
 - 
 -     def is_reff(self, exp: str) -> bool:
 -         exp = exp.strip("{").strip("}")
 -         if exp.find("@") < 0:
 -             return exp in self.globals
 -         arr = exp.split("@")
 -         if len(arr) != 2:
 -             return False
 -         if self.get_component(arr[0]) is None:
 -             return False
 -         return True
 - 
 -     def get_variable_value(self, exp: str) -> Any:
 -         exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
 -         if exp.find("@") < 0:
 -             return self.globals[exp]
 -         cpn_id, var_nm = exp.split("@")
 -         cpn = self.get_component(cpn_id)
 -         if not cpn:
 -             raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
 -         return cpn["obj"].output(var_nm)
 - 
 -     def get_tenant_id(self):
 -         return self._tenant_id
 - 
 -     def get_history(self, window_size):
 -         convs = []
 -         if window_size <= 0:
 -             return convs
 -         for role, obj in self.history[window_size * -1:]:
 -             if isinstance(obj, dict):
 -                 convs.append({"role": role, "content": obj.get("content", "")})
 -             else:
 -                 convs.append({"role": role, "content": str(obj)})
 -         return convs
 - 
 -     def add_user_input(self, question):
 -         self.history.append(("user", question))
 - 
 -     def _find_loop(self, max_loops=6):
 -         path = self.path[-1][::-1]
 -         if len(path) < 2:
 -             return False
 - 
 -         for i in range(len(path)):
 -             if path[i].lower().find("answer") == 0 or path[i].lower().find("iterationitem") == 0:
 -                 path = path[:i]
 -                 break
 - 
 -         if len(path) < 2:
 -             return False
 - 
 -         for loc in range(2, len(path) // 2):
 -             pat = ",".join(path[0:loc])
 -             path_str = ",".join(path)
 -             if len(pat) >= len(path_str):
 -                 return False
 -             loop = max_loops
 -             while path_str.find(pat) == 0 and loop >= 0:
 -                 loop -= 1
 -                 if len(pat)+1 >= len(path_str):
 -                     return False
 -                 path_str = path_str[len(pat)+1:]
 -             if loop < 0:
 -                 pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
 -                 return pat + " => " + pat
 - 
 -         return False
 - 
 -     def get_prologue(self):
 -         return self.components["begin"]["obj"]._param.prologue
 - 
 -     def set_global_param(self, **kwargs):
 -         self.globals.update(kwargs)
 - 
 -     def get_preset_param(self):
 -         return self.components["begin"]["obj"]._param.inputs
 - 
 -     def get_component_input_elements(self, cpnnm):
 -         return self.components[cpnnm]["obj"].get_input_elements()
 - 
 -     def get_files(self, files: Union[None, list[dict]]) -> list[str]:
 -         if not files:
 -             return  []
 -         def image_to_base64(file):
 -             return "data:{};base64,{}".format(file["mime_type"],
 -                                         base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
 -         exe = ThreadPoolExecutor(max_workers=5)
 -         threads = []
 -         for file in files:
 -             if file["mime_type"].find("image") >=0:
 -                 threads.append(exe.submit(image_to_base64, file))
 -                 continue
 -             threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
 -         return [th.result() for th in threads]
 - 
 -     def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any):
 -         agent_ids = agent_id.split("-->")
 -         agent_name = self.get_component_name(agent_ids[0])
 -         path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
 -         try:
 -             bin = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs")
 -             if bin:
 -                 obj = json.loads(bin.encode("utf-8"))
 -                 if obj[-1]["component_id"] == agent_ids[0]:
 -                     obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result})
 -                 else:
 -                     obj.append({
 -                     "component_id": agent_ids[0],
 -                     "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}]
 -                 })
 -             else:
 -                 obj = [{
 -                     "component_id": agent_ids[0],
 -                     "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}]
 -                 }]
 -             REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
 -         except Exception as e:
 -             logging.exception(e)
 - 
 -     def add_refernce(self, chunks: list[object], doc_infos: list[object]):
 -         if not self.retrieval:
 -             self.retrieval = [{"chunks": {}, "doc_aggs": {}}]
 - 
 -         r = self.retrieval[-1]
 -         for ck in chunks_format({"chunks": chunks}):
 -             cid = hash_str2int(ck["id"], 100)
 -             if cid not in r:
 -                 r["chunks"][cid] = ck
 - 
 -         for doc in doc_infos:
 -             if doc["doc_name"] not in r:
 -                 r["doc_aggs"][doc["doc_name"]] = doc
 - 
 -     def get_reference(self):
 -         if not self.retrieval:
 -             return {"chunks": {}, "doc_aggs": {}}
 -         return self.retrieval[-1]
 - 
 -     def add_memory(self, user:str, assist:str, summ: str):
 -         self.memory.append((user, assist, summ))
 - 
 -     def get_memory(self) -> list[Tuple]:
 -         return self.memory
 - 
 -     def get_component_thoughts(self, cpn_id) -> str:
 -         return self.components.get(cpn_id)["obj"].thoughts()
 
 
  |