# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging import re import time from copy import deepcopy from functools import partial from typing import TypedDict, List, Any from agent.component.base import ComponentParamBase, ComponentBase from api.utils import hash_str2int from rag.llm.chat_model import ToolCallSession from rag.prompts.prompts import kb_prompt from rag.utils.mcp_tool_call_conn import MCPToolCallSession class ToolParameter(TypedDict): type: str description: str displayDescription: str enum: List[str] required: bool class ToolMeta(TypedDict): name: str displayName: str description: str displayDescription: str parameters: dict[str, ToolParameter] class LLMToolPluginCallSession(ToolCallSession): def __init__(self, tools_map: dict[str, object], callback: partial): self.tools_map = tools_map self.callback = callback def tool_call(self, name: str, arguments: dict[str, Any]) -> Any: assert name in self.tools_map, f"LLM tool {name} does not exist" if isinstance(self.tools_map[name], MCPToolCallSession): resp = self.tools_map[name].tool_call(name, arguments, 60) else: resp = self.tools_map[name].invoke(**arguments) self.callback(name, arguments, resp) return resp def get_tool_obj(self, name): return self.tools_map[name] class ToolParamBase(ComponentParamBase): def __init__(self): #self.meta:ToolMeta = None super().__init__() self._init_inputs() self._init_attr_by_meta() def _init_inputs(self): self.inputs = {} for k,p in self.meta["parameters"].items(): self.inputs[k] = deepcopy(p) def _init_attr_by_meta(self): for k,p in self.meta["parameters"].items(): if not hasattr(self, k): setattr(self, k, p.get("default")) def get_meta(self): params = {} for k, p in self.meta["parameters"].items(): params[k] = { "type": p["type"], "description": p["description"] } if "enum" in p: params[k]["enum"] = p["enum"] desc = self.meta["description"] if hasattr(self, "description"): desc = self.description function_name = self.meta["name"] if hasattr(self, "function_name"): function_name = self.function_name return { "type": "function", "function": { "name": function_name, "description": desc, "parameters": { "type": "object", "properties": params, "required": [k for k, p in self.meta["parameters"].items() if p["required"]] } } } class ToolBase(ComponentBase): def __init__(self, canvas, id, param: ComponentParamBase): from agent.canvas import Canvas # Local import to avoid cyclic dependency assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas" self._canvas = canvas self._id = id self._param = param self._param.check() def get_meta(self) -> dict[str, Any]: return self._param.get_meta() def invoke(self, **kwargs): self.set_output("_created_time", time.perf_counter()) try: res = self._invoke(**kwargs) except Exception as e: self._param.outputs["_ERROR"] = {"value": str(e)} logging.exception(e) res = str(e) self._param.debug_inputs = [] self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return res def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None): chunks = [] aggs = [] for r in res_list: content = get_content(r) if not content: continue content = re.sub(r"!?\[[a-z]+\]\(data:image/png;base64,[ 0-9A-Za-z/_=+-]+\)", "", content) content = content[:10000] if not content: continue id = str(hash_str2int(content)) title = get_title(r) url = get_url(r) score = get_score(r) if get_score else 1 chunks.append({ "chunk_id": id, "content": content, "doc_id": id, "docnm_kwd": title, "similarity": score, "url": url }) aggs.append({ "doc_name": title, "doc_id": id, "count": 1, "url": url }) self._canvas.add_refernce(chunks, aggs) self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True))) def thoughts(self) -> str: return self._canvas.get_component_name(self._id) + " is running..."