
agent_with_tools.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any

import json_repair

from agent.component.llm import LLMParam, LLM
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool


class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.
    """

    def __init__(self):
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {
                    "type": "string",
                    "description": "This is the order you need to send to the agent.",
                    "default": "",
                    "required": True
                },
                "reasoning": {
                    "type": "string",
                    "description": (
                        "Supervisor's reasoning for choosing this agent. "
                        "Explain why this agent is being invoked and what is expected of it."
                    ),
                    "required": True
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, "
                        "and state needed by the agent to solve the current query. "
                        "Should be as detailed and self-contained as possible."
                    ),
                    "required": True
                },
            }
        }
        super().__init__()
        self.function_name = "agent"
        self.tools = []
        self.mcp = []
        self.max_rounds = 5
        self.description = ""
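

# The Agent component couples an LLM with callable tools (canvas tool
# sub-components plus MCP-served tools) and drives a plan -> act -> reflect
# ReAct-style loop until the model emits COMPLETE_TASK or max_rounds is hit.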
class Agent(LLM, ToolBase):
    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        LLM.__init__(self, canvas, id, param)
        # Instantiate tool sub-components and index them by function name.
        self.tools = {}
        for cpn in self._param.tools:
            cpn = self._load_tool_obj(cpn)
            self.tools[cpn.get_meta()["function"]["name"]] = cpn

        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                  max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error,
                                  max_rounds=self._param.max_rounds,
                                  verbose_tool_use=True
                                  )

        # Register MCP-served tools alongside the local ones.
        self.tool_meta = [v.get_meta() for v in self.tools.values()]
        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            for tnm, meta in mcp["tools"].items():
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
                self.tools[tnm] = tool_call_session

        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        # self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)
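
    # Instantiate a tool sub-component from its serialized canvas config; the
    # child id is namespaced as "<agent_id>--><tool_name>" (the "-->" marker is
    # later used to tell sub-agents apart from the top-level agent).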
    def _load_tool_obj(self, cpn: dict) -> object:
        from agent.component import component_class
        param = component_class(cpn["component_name"] + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
            raise
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(cpn["component_name"])(self._canvas, cpn_id, param)
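
    # Expose this agent itself as an OpenAI-style tool: the function name is
    # taken from the last segment of the component id, and a configured
    # user_prompt overrides the default user_prompt parameter schema.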
    def get_meta(self) -> dict[str, Any]:
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
        return m
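
    # Gather the input form of this agent plus that of any LLM-based tool
    # components, so every required field can be presented in one place.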
    def get_input_form(self) -> dict[str, dict]:
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        for cpn in self._param.tools:
            if not isinstance(cpn, LLM):
                continue
            res.update(cpn.get_input_form())
        return res
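
    # Entry point. Without tools this falls back to a plain LLM call; with a
    # downstream Message component (and no structured output) the answer is
    # streamed, otherwise the ReAct loop runs to completion and the final
    # answer is stored on the component outputs.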
    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20 * 60)))  # env values are strings; coerced to int here
    def _invoke(self, **kwargs):
        if kwargs.get("user_prompt"):
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]

        if not self.tools:
            return LLM._invoke(self, **kwargs)

        prompt, msg = self._prepare_prompt_variables()
        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not self._param.output_structure:
            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
            return

        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        use_tools = []
        ans = ""
        for delta_ans, tk in self._react_with_tools_streamly(msg, use_tools):
            ans += delta_ans

        if ans.find("**ERROR**") >= 0:
            logging.error(f"Agent._chat got error. response: {ans}")
            self.set_output("_ERROR", ans)
            return

        self.set_output("content", ans)
        if use_tools:
            self.set_output("use_tools", use_tools)
        return ans
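
    # Streaming variant: yields deltas to the caller while accumulating the
    # full answer (and any tool-call records) for the component outputs.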
    def stream_output_with_tools(self, prompt, msg):
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer_without_toolcall = ""
        use_tools = []
        for delta_ans, _ in self._react_with_tools_streamly(msg, use_tools):
            answer_without_toolcall += delta_ans
            yield delta_ans

        self.set_output("content", answer_without_toolcall)
        if use_tools:
            self.set_output("use_tools", use_tools)
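
    # Second-pass citation generation: re-renders the draft answer grounded in
    # the chunks retrieved during the run (collected on the canvas).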
    def _gen_citations(self, text):
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
                                                  {"role": "user", "content": text}
                                                  ]):
            yield delta_ans
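
    # Core ReAct loop; yields (delta_text, token_count) tuples. For multi-turn
    # histories the user request is first condensed via full_question().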
    def _react_with_tools_streamly(self, history: list[dict], use_tools):
        token_count = 0
        tool_metas = self.tool_meta
        hist = deepcopy(history)
        last_calling = ""
        if len(hist) > 3:
            self.callback("Multi-turn conversation optimization", {}, " running ...")
            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
        else:
            user_request = history[-1]["content"]
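
        # Execute one tool call and record it; submitted to a thread pool below.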
        def use_tool(name, args):
            nonlocal hist, use_tools, token_count, last_calling, user_request
            logging.debug(f"{last_calling=} == {name=}")
            # Summarization of the previous function calling (disabled for now):
            # if all([
            #     isinstance(self.toolcall_session.get_tool_obj(name), Agent),
            #     last_calling,
            #     last_calling != name
            # ]):
            #     self.toolcall_session.get_tool_obj(name).add2system_prompt("The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
            last_calling = name
            tool_response = self.toolcall_session.tool_call(name, args)
            use_tools.append({
                "name": name,
                "arguments": args,
                "results": tool_response
            })
            # self.callback("add_memory", {}, "...")
            # self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
            return name, tool_response
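
        # Produce the final answer. For short histories the citation prompt is
        # injected into the system message; otherwise the draft is buffered and
        # re-emitted with citations in a second pass. Sub-agents (ids containing
        # "-->") skip citations entirely.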
        def complete():
            nonlocal hist
            need2cite = self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
            cited = False
            if hist[0]["role"] == "system" and need2cite:
                if len(hist) < 7:
                    hist[0]["content"] += citation_prompt()
                    cited = True

            yield "", token_count

            if not cited and need2cite:
                self.callback("gen_citations", {}, " running ...")

            _hist = hist
            if len(hist) > 12:
                # Keep the system prompt, the first user turn and the latest exchanges.
                _hist = [hist[0], hist[1], *hist[-10:]]

            entire_txt = ""
            for delta_ans in self._generate_streamly(_hist):
                if not need2cite or cited:
                    yield delta_ans, 0
                entire_txt += delta_ans
            if not need2cite or cited:
                return

            for delta_ans in self._gen_citations(entire_txt):
                yield delta_ans, 0
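
        # Fold follow-up instructions into the trailing user turn so roles
        # keep alternating cleanly.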
        def append_user_content(hist, content):
            if hist[-1]["role"] == "user":
                hist[-1]["content"] += content
            else:
                hist.append({"role": "user", "content": content})
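
        # Plan/act/reflect: ask the LLM for the next step(s), run the requested
        # tool calls in parallel, reflect on the results, repeat up to max_rounds.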
        self.callback("analyze_task", {}, " running ...")
        task_desc = analyze_task(self.chat_mdl, user_request, tool_metas)
        for _ in range(self._param.max_rounds + 1):
            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
            # self.callback("next_step", {}, str(response)[:256]+"...")
            token_count += tk
            hist.append({"role": "assistant", "content": response})
            try:
                functions = json_repair.loads(re.sub(r"```.*", "", response))
                if not isinstance(functions, list):
                    raise TypeError(f"A list should be returned, but got `{functions}`")
                for f in functions:
                    if not isinstance(f, dict):
                        raise TypeError(f"Each item should be an object, but got `{f}`")
                with ThreadPoolExecutor(max_workers=5) as executor:
                    thr = []
                    for func in functions:
                        name = func["name"]
                        args = func["arguments"]
                        if name == COMPLETE_TASK:
                            append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be the same as the first user request.\n")
                            for txt, tkcnt in complete():
                                yield txt, tkcnt
                            return

                        thr.append(executor.submit(use_tool, name, args))

                    reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
                    append_user_content(hist, reflection)
                    self.callback("reflection", {}, str(reflection))
            except Exception as e:
                logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
                e = f"\nTool call error, please correct the input parameters or the response format and call it again.\n*** Exception ***\n{e}"
                append_user_content(hist, str(e))
        logging.warning(f"Exceeded max rounds: {self._param.max_rounds}")
        final_instruction = f"""
{user_request}

IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.

Instructions:
1. SYNTHESIZE all information collected during this conversation
2. Provide a COMPLETE response using existing data - do not suggest additional research
3. Structure your response as a FINAL DELIVERABLE, not a plan
4. If information is incomplete, state what you found and provide the best analysis possible with available data
5. DO NOT mention conversation limits or suggest further steps
6. Focus on delivering VALUE with the information already gathered

Respond immediately with your final comprehensive answer.
"""
        append_user_content(hist, final_instruction)
        for txt, tkcnt in complete():
            yield txt, tkcnt
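
    # Rank stored canvas memories (user/assistant/summary triples) by relevance
    # to the goal via the LLM and return the top-n exchanges as plain text.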
    def get_useful_memory(self, goal: str, sub_goal: str, topn=3) -> str:
        # self.callback("get_useful_memory", {"topn": 3}, "...")
        mems = self._canvas.get_memory()
        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
        try:
            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
            mems = [mems[r] for r in rank]
            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a, _ in mems])
        except Exception as e:
            logging.exception(e)
            return "Error occurred."