You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base.py 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import re
  18. import time
  19. from copy import deepcopy
  20. from functools import partial
  21. from typing import TypedDict, List, Any
  22. from agent.component.base import ComponentParamBase, ComponentBase
  23. from api.utils import hash_str2int
  24. from rag.llm.chat_model import ToolCallSession
  25. from rag.prompts.prompts import kb_prompt
  26. from rag.utils.mcp_tool_call_conn import MCPToolCallSession
  27. from timeit import default_timer as timer
  28. class ToolParameter(TypedDict):
  29. type: str
  30. description: str
  31. displayDescription: str
  32. enum: List[str]
  33. required: bool
  34. class ToolMeta(TypedDict):
  35. name: str
  36. displayName: str
  37. description: str
  38. displayDescription: str
  39. parameters: dict[str, ToolParameter]
  40. class LLMToolPluginCallSession(ToolCallSession):
  41. def __init__(self, tools_map: dict[str, object], callback: partial):
  42. self.tools_map = tools_map
  43. self.callback = callback
  44. def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
  45. assert name in self.tools_map, f"LLM tool {name} does not exist"
  46. st = timer()
  47. if isinstance(self.tools_map[name], MCPToolCallSession):
  48. resp = self.tools_map[name].tool_call(name, arguments, 60)
  49. else:
  50. resp = self.tools_map[name].invoke(**arguments)
  51. self.callback(name, arguments, resp, elapsed_time=timer()-st)
  52. return resp
  53. def get_tool_obj(self, name):
  54. return self.tools_map[name]
  55. class ToolParamBase(ComponentParamBase):
  56. def __init__(self):
  57. #self.meta:ToolMeta = None
  58. super().__init__()
  59. self._init_inputs()
  60. self._init_attr_by_meta()
  61. def _init_inputs(self):
  62. self.inputs = {}
  63. for k,p in self.meta["parameters"].items():
  64. self.inputs[k] = deepcopy(p)
  65. def _init_attr_by_meta(self):
  66. for k,p in self.meta["parameters"].items():
  67. if not hasattr(self, k):
  68. setattr(self, k, p.get("default"))
  69. def get_meta(self):
  70. params = {}
  71. for k, p in self.meta["parameters"].items():
  72. params[k] = {
  73. "type": p["type"],
  74. "description": p["description"]
  75. }
  76. if "enum" in p:
  77. params[k]["enum"] = p["enum"]
  78. desc = self.meta["description"]
  79. if hasattr(self, "description"):
  80. desc = self.description
  81. function_name = self.meta["name"]
  82. if hasattr(self, "function_name"):
  83. function_name = self.function_name
  84. return {
  85. "type": "function",
  86. "function": {
  87. "name": function_name,
  88. "description": desc,
  89. "parameters": {
  90. "type": "object",
  91. "properties": params,
  92. "required": [k for k, p in self.meta["parameters"].items() if p["required"]]
  93. }
  94. }
  95. }
  96. class ToolBase(ComponentBase):
  97. def __init__(self, canvas, id, param: ComponentParamBase):
  98. from agent.canvas import Canvas # Local import to avoid cyclic dependency
  99. assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
  100. self._canvas = canvas
  101. self._id = id
  102. self._param = param
  103. self._param.check()
  104. def get_meta(self) -> dict[str, Any]:
  105. return self._param.get_meta()
  106. def invoke(self, **kwargs):
  107. self.set_output("_created_time", time.perf_counter())
  108. try:
  109. res = self._invoke(**kwargs)
  110. except Exception as e:
  111. self._param.outputs["_ERROR"] = {"value": str(e)}
  112. logging.exception(e)
  113. res = str(e)
  114. self._param.debug_inputs = []
  115. self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
  116. return res
  117. def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
  118. chunks = []
  119. aggs = []
  120. for r in res_list:
  121. content = get_content(r)
  122. if not content:
  123. continue
  124. content = re.sub(r"!?\[[a-z]+\]\(data:image/png;base64,[ 0-9A-Za-z/_=+-]+\)", "", content)
  125. content = content[:10000]
  126. if not content:
  127. continue
  128. id = str(hash_str2int(content))
  129. title = get_title(r)
  130. url = get_url(r)
  131. score = get_score(r) if get_score else 1
  132. chunks.append({
  133. "chunk_id": id,
  134. "content": content,
  135. "doc_id": id,
  136. "docnm_kwd": title,
  137. "similarity": score,
  138. "url": url
  139. })
  140. aggs.append({
  141. "doc_name": title,
  142. "doc_id": id,
  143. "count": 1,
  144. "url": url
  145. })
  146. self._canvas.add_refernce(chunks, aggs)
  147. self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True)))
  148. def thoughts(self) -> str:
  149. return self._canvas.get_component_name(self._id) + " is running..."