
agent_with_tools.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any

import json_repair

from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
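

# AgentParam extends the plain LLM parameters with a tool-call schema (meta),
# so an Agent can itself be exposed to a supervisor agent as a callable
# function taking user_prompt, reasoning and context.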
class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.
    """

    def __init__(self):
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {
                    "type": "string",
                    "description": "This is the order you need to send to the agent.",
                    "default": "",
                    "required": True
                },
                "reasoning": {
                    "type": "string",
                    "description": (
                        "Supervisor's reasoning for choosing this agent. "
                        "Explain why this agent is being invoked and what is expected of it."
                    ),
                    "required": True
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, "
                        "and state needed by the agent to solve the current query. "
                        "Should be as detailed and self-contained as possible."
                    ),
                    "required": True
                },
            }
        }
        super().__init__()
        self.function_name = "agent"
        self.tools = []
        self.mcp = []
        self.max_rounds = 5
        self.description = ""
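

# Agent couples the LLM component with ToolBase: it loads local tool
# components and MCP tool sessions, registers their metadata for function
# calling, and drives a ReAct-style loop over them. Because it is also a
# tool, agents can be nested; composite ids are joined with "-->".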
class Agent(LLM, ToolBase):
    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        LLM.__init__(self, canvas, id, param)
        self.tools = {}
        for cpn in self._param.tools:
            cpn = self._load_tool_obj(cpn)
            self.tools[cpn.get_meta()["function"]["name"]] = cpn

        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                  max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error,
                                  max_rounds=self._param.max_rounds,
                                  verbose_tool_use=True
                                  )

        self.tool_meta = [v.get_meta() for _, v in self.tools.items()]
        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            for tnm, meta in mcp["tools"].items():
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
                self.tools[tnm] = tool_call_session

        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        #self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)
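
    # Instantiate a tool component from its serialized canvas config; a bad
    # parameter set is reported on the "_ERROR" output and re-raised.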
    def _load_tool_obj(self, cpn: dict) -> object:
        from agent.component import component_class
        param = component_class(cpn["component_name"] + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
            raise
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(cpn["component_name"])(self._canvas, cpn_id, param)
    def get_meta(self) -> dict[str, Any]:
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
        return m

    def get_input_form(self) -> dict[str, dict]:
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        for cpn in self._param.tools:
            if not isinstance(cpn, LLM):
                continue
            res.update(cpn.get_input_form())
        return res
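
    # One agent turn. A supervisor's user_prompt/reasoning/context kwargs are
    # folded into a single user message; without tools this falls back to the
    # plain LLM invocation; with a downstream Message component the answer is
    # streamed, otherwise it is accumulated before the outputs are set.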
    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))
    def _invoke(self, **kwargs):
        if kwargs.get("user_prompt"):
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]

        if not self.tools:
            return LLM._invoke(self, **kwargs)

        prompt, msg = self._prepare_prompt_variables()
        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
        ex = self.exception_handler()
        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
            return

        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        use_tools = []
        ans = ""
        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
            ans += delta_ans

        if ans.find("**ERROR**") >= 0:
            logging.error(f"Agent._chat got error. response: {ans}")
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", ans)
            return

        self.set_output("content", ans)
        if use_tools:
            self.set_output("use_tools", use_tools)
        return ans
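
    # Streaming variant used when a Message component consumes the output:
    # deltas are yielded as they arrive and the concatenated answer is stored
    # on the "content" output at the end.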
    def stream_output_with_tools(self, prompt, msg):
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer_without_toolcall = ""
        use_tools = []
        for delta_ans, _ in self._react_with_tools_streamly(prompt, msg, use_tools):
            if delta_ans.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", delta_ans)
            answer_without_toolcall += delta_ans
            yield delta_ans

        self.set_output("content", answer_without_toolcall)
        if use_tools:
            self.set_output("use_tools", use_tools)
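
    # Re-run a draft answer through the model with the retrieved chunks in the
    # system prompt so the final text carries inline citations.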
    def _gen_citations(self, text):
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
                                                  {"role": "user", "content": text}
                                                  ]):
            yield delta_ans
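
    # The ReAct loop: next_step() is asked for a JSON list of tool calls,
    # which are executed in parallel; reflect()'s critique is appended to the
    # history, and the loop ends when the model emits COMPLETE_TASK or
    # max_rounds is exhausted, at which point a final answer is forced.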
    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
        token_count = 0
        tool_metas = self.tool_meta
        hist = deepcopy(history)
        last_calling = ""
        if len(hist) > 3:
            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
            self.callback("Multi-turn conversation optimization", {}, user_request)
        else:
            user_request = history[-1]["content"]

        def use_tool(name, args):
            nonlocal hist, use_tools, token_count, last_calling, user_request
            logging.info(f"{last_calling=} == {name=}")
            # Summary of the function calling
            #if all([
            #    isinstance(self.toolcall_session.get_tool_obj(name), Agent),
            #    last_calling,
            #    last_calling != name
            #]):
            #    self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
            last_calling = name
            tool_response = self.toolcall_session.tool_call(name, args)
            use_tools.append({
                "name": name,
                "arguments": args,
                "results": tool_response
            })
            # self.callback("add_memory", {}, "...")
            #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
            return name, tool_response
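
        # complete() streams the final answer: short histories get the
        # citation prompt injected into the system message, long histories are
        # truncated to the first two and last ten messages, and any citations
        # still owed are generated in a second pass via _gen_citations().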
        def complete():
            nonlocal hist
            need2cite = self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
            cited = False
            if hist[0]["role"] == "system" and need2cite:
                if len(hist) < 7:
                    hist[0]["content"] += citation_prompt()
                    cited = True

            yield "", token_count
            _hist = hist
            if len(hist) > 12:
                _hist = [hist[0], hist[1], *hist[-10:]]
            entire_txt = ""
            for delta_ans in self._generate_streamly(_hist):
                if not need2cite or cited:
                    yield delta_ans, 0
                entire_txt += delta_ans
            if not need2cite or cited:
                return

            txt = ""
            for delta_ans in self._gen_citations(entire_txt):
                yield delta_ans, 0
                txt += delta_ans
            self.callback("gen_citations", {}, txt)

        def append_user_content(hist, content):
            if hist[-1]["role"] == "user":
                hist[-1]["content"] += content
            else:
                hist.append({"role": "user", "content": content})

        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
        self.callback("analyze_task", {}, task_desc)
        for _ in range(self._param.max_rounds + 1):
            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
            # self.callback("next_step", {}, str(response)[:256]+"...")
            token_count += tk
            hist.append({"role": "assistant", "content": response})
            try:
                functions = json_repair.loads(re.sub(r"```.*", "", response))
                if not isinstance(functions, list):
                    raise TypeError(f"A list should be returned, but got `{functions}`")
                for f in functions:
                    if not isinstance(f, dict):
                        raise TypeError(f"An object should be returned, but got `{f}`")
                with ThreadPoolExecutor(max_workers=5) as executor:
                    thr = []
                    for func in functions:
                        name = func["name"]
                        args = func["arguments"]
                        if name == COMPLETE_TASK:
                            append_user_content(hist, f"Respond with a formal answer. FORGET (DO NOT mention) `{COMPLETE_TASK}`. The language of the response MUST be the same as that of the first user request.\n")
                            for txt, tkcnt in complete():
                                yield txt, tkcnt
                            return

                        thr.append(executor.submit(use_tool, name, args))

                    reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
                    append_user_content(hist, reflection)
                    self.callback("reflection", {}, str(reflection))
            except Exception as e:
                logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
                e = f"\nTool call error, please correct the input parameters or the response format and call it again.\n*** Exception ***\n{e}"
                append_user_content(hist, str(e))

        logging.warning(f"Exceeded max rounds: {self._param.max_rounds}")
        final_instruction = f"""
{user_request}
IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
Instructions:
1. SYNTHESIZE all information collected during this conversation
2. Provide a COMPLETE response using existing data - do not suggest additional research
3. Structure your response as a FINAL DELIVERABLE, not a plan
4. If information is incomplete, state what you found and provide the best analysis possible with available data
5. DO NOT mention conversation limits or suggest further steps
6. Focus on delivering VALUE with the information already gathered
Respond immediately with your final comprehensive answer.
"""
        append_user_content(hist, final_instruction)
        for txt, tkcnt in complete():
            yield txt, tkcnt
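
    # Rank stored canvas memories against the goal/sub-goal with the chat
    # model and return the top-n exchanges as text; currently referenced only
    # by the commented-out cross-agent summarization in use_tool() above.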
    def get_useful_memory(self, goal: str, sub_goal: str, topn=3) -> str:
        # self.callback("get_useful_memory", {"topn": 3}, "...")
        mems = self._canvas.get_memory()
        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
        try:
            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
            mems = [mems[r] for r in rank]
            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a, _ in mems])
        except Exception as e:
            logging.exception(e)
            return "Error occurred."