agent_with_tools.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any

import json_repair

from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
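

# This module implements the tool-using Agent component: an LLM-driven planner
# that runs a ReAct-style loop, invokes local tool components and MCP-server
# tools, reflects on their results, and optionally adds citations from
# retrieved chunks before emitting its final answer.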
class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.
    """

    def __init__(self):
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {
                    "type": "string",
                    "description": "This is the order you need to send to the agent.",
                    "default": "",
                    "required": True
                },
                "reasoning": {
                    "type": "string",
                    "description": (
                        "The supervisor's reasoning for choosing this agent. "
                        "Explain why this agent is being invoked and what is expected of it."
                    ),
                    "required": True
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, "
                        "and state needed by the agent to solve the current query. "
                        "Should be as detailed and self-contained as possible."
                    ),
                    "required": True
                },
            }
        }
        super().__init__()
        self.function_name = "agent"
        self.tools = []
        self.mcp = []
        self.max_rounds = 5
        self.description = ""
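

# Agent couples an LLM with a set of callable tools. It is both an LLM
# component (inheriting prompt handling and streaming) and a ToolBase, so a
# supervisor agent can invoke it as a sub-agent through the tool interface.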
class Agent(LLM, ToolBase):
    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        LLM.__init__(self, canvas, id, param)
        # Instantiate every configured tool component, indexed by function name.
        self.tools = {}
        for cpn in self._param.tools:
            cpn = self._load_tool_obj(cpn)
            self.tools[cpn.get_meta()["function"]["name"]] = cpn
        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                  max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error,
                                  max_rounds=self._param.max_rounds,
                                  verbose_tool_use=True
                                  )
        self.tool_meta = [v.get_meta() for _, v in self.tools.items()]
        # Register tools exposed by the configured MCP servers alongside the
        # local ones; all tools of a server share one call session.
        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            for tnm, meta in mcp["tools"].items():
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
                self.tools[tnm] = tool_call_session
        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        # self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)
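
    # Instantiate a tool component from its canvas config: build and validate
    # its Param object, then construct the component with an id of the form
    # "<agent_id>--><tool_name>" so nested components can be traced back here.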
    def _load_tool_obj(self, cpn: dict) -> object:
        from agent.component import component_class
        param = component_class(cpn["component_name"] + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
            raise
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(cpn["component_name"])(self._canvas, cpn_id, param)
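
    # Expose this agent as an OpenAI-style tool. The function name is taken
    # from the last segment of the component id; a configured user_prompt
    # overrides the schema entry for the "user_prompt" parameter.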
    def get_meta(self) -> dict[str, Any]:
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
        return m

    def get_input_form(self) -> dict[str, dict]:
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        for cpn in self._param.tools:
            if not isinstance(cpn, LLM):
                continue
            res.update(cpn.get_input_form())
        return res
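
    # Entry point when the canvas executes this component. Supervisor-supplied
    # reasoning/context are folded into the user prompt. Without tools this
    # degrades to a plain LLM call; with a downstream Message component the
    # answer is streamed, otherwise the ReAct loop is drained right here.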
    @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20 * 60))
    def _invoke(self, **kwargs):
        if kwargs.get("user_prompt"):
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]

        if not self.tools:
            return LLM._invoke(self, **kwargs)

        prompt, msg = self._prepare_prompt_variables()

        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
        ex = self.exception_handler()
        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
            return

        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        use_tools = []
        ans = ""
        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
            ans += delta_ans

        if ans.find("**ERROR**") >= 0:
            logging.error(f"Agent._chat got error. response: {ans}")
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", ans)
            return

        self.set_output("content", ans)
        if use_tools:
            self.set_output("use_tools", use_tools)
        return ans
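
    # Streaming variant used when a downstream Message component consumes the
    # output: yields deltas as they arrive and records the accumulated answer
    # (or the exception default / error) on the component outputs.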
    def stream_output_with_tools(self, prompt, msg):
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer_without_toolcall = ""
        use_tools = []
        for delta_ans, _ in self._react_with_tools_streamly(prompt, msg, use_tools):
            if delta_ans.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", delta_ans)
            answer_without_toolcall += delta_ans
            yield delta_ans
        self.set_output("content", answer_without_toolcall)
        if use_tools:
            self.set_output("use_tools", use_tools)
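
    # Re-run the drafted answer through the LLM together with the retrieved
    # chunks so the model can insert citation markers that refer to them.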
    def _gen_citations(self, text):
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formatted_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formatted_refer))},
                                                  {"role": "user", "content": text}
                                                  ]):
            yield delta_ans
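
    # Core ReAct loop: analyze the task, repeatedly ask the LLM for the next
    # batch of tool calls (executed in parallel), reflect on the results, and
    # finish with a final, optionally cited answer. Yields (delta, tokens).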
    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
        token_count = 0
        tool_metas = self.tool_meta
        hist = deepcopy(history)
        last_calling = ""
        # For longer conversations, condense the history into one standalone
        # question; otherwise use the latest user message as-is.
        if len(hist) > 3:
            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
            self.callback("Multi-turn conversation optimization", {}, user_request)
        else:
            user_request = history[-1]["content"]

        def use_tool(name, args):
            nonlocal hist, use_tools, token_count, last_calling, user_request
            logging.info(f"{last_calling=} == {name=}")
            # Summarize the function calling (disabled):
            # if all([
            #     isinstance(self.toolcall_session.get_tool_obj(name), Agent),
            #     last_calling,
            #     last_calling != name
            # ]):
            #     self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
            last_calling = name
            tool_response = self.toolcall_session.tool_call(name, args)
            use_tools.append({
                "name": name,
                "arguments": args,
                "results": tool_response
            })
            # self.callback("add_memory", {}, "...")
            # self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
            return name, tool_response
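
        # Produce the final answer. Citations are requested only for top-level
        # agents (no "-->" in the id) when retrieved chunks exist; very long
        # histories are truncated to the first two and the last ten messages.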
        def complete():
            nonlocal hist
            need2cite = self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
            cited = False
            if hist[0]["role"] == "system" and need2cite:
                if len(hist) < 7:
                    hist[0]["content"] += citation_prompt()
                    cited = True
            yield "", token_count
            _hist = hist
            if len(hist) > 12:
                _hist = [hist[0], hist[1], *hist[-10:]]
            entire_txt = ""
            for delta_ans in self._generate_streamly(_hist):
                if not need2cite or cited:
                    yield delta_ans, 0
                entire_txt += delta_ans
            if not need2cite or cited:
                return
            txt = ""
            for delta_ans in self._gen_citations(entire_txt):
                yield delta_ans, 0
                txt += delta_ans
            self.callback("gen_citations", {}, txt)
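
        # Merge consecutive user messages so the history keeps alternating roles.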
        def append_user_content(hist, content):
            if hist[-1]["role"] == "user":
                hist[-1]["content"] += content
            else:
                hist.append({"role": "user", "content": content})
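
        # Plan first, then iterate up to max_rounds + 1 times: each round asks
        # the LLM for a JSON list of tool calls, executes them in parallel, and
        # feeds the reflection back into the history as user content.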
        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
        self.callback("analyze_task", {}, task_desc)
        for _ in range(self._param.max_rounds + 1):
            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
            # self.callback("next_step", {}, str(response)[:256]+"...")
            token_count += tk
            hist.append({"role": "assistant", "content": response})
            try:
                functions = json_repair.loads(re.sub(r"```.*", "", response))
                if not isinstance(functions, list):
                    raise TypeError(f"A list should be returned, but got `{functions}`")
                for f in functions:
                    if not isinstance(f, dict):
                        raise TypeError(f"An object should be returned for every item, but got `{f}`")
                with ThreadPoolExecutor(max_workers=5) as executor:
                    thr = []
                    for func in functions:
                        name = func["name"]
                        args = func["arguments"]
                        if name == COMPLETE_TASK:
                            append_user_content(hist, f"Respond with a formal answer. FORGET (DO NOT mention) `{COMPLETE_TASK}`. The language of the response MUST be the same as that of the first user request.\n")
                            for txt, tkcnt in complete():
                                yield txt, tkcnt
                            return
                        thr.append(executor.submit(use_tool, name, args))
                    reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
                    append_user_content(hist, reflection)
                    self.callback("reflection", {}, str(reflection))
            except Exception as e:
                logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
                e = f"\nTool call error, please correct the input parameters according to the response format and call it again.\n *** Exception ***\n{e}"
                append_user_content(hist, str(e))

        logging.warning(f"Exceeded max rounds: {self._param.max_rounds}")
        final_instruction = f"""
{user_request}

IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.

Instructions:
1. SYNTHESIZE all information collected during this conversation
2. Provide a COMPLETE response using existing data - do not suggest additional research
3. Structure your response as a FINAL DELIVERABLE, not a plan
4. If information is incomplete, state what you found and provide the best analysis possible with available data
5. DO NOT mention conversation limits or suggest further steps
6. Focus on delivering VALUE with the information already gathered

Respond immediately with your final comprehensive answer.
"""
        append_user_content(hist, final_instruction)
        for txt, tkcnt in complete():
            yield txt, tkcnt
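
    # Rank stored canvas memories against the current goal and return the topn
    # most relevant exchanges; referenced by the commented-out cross-agent
    # summarization in use_tool above.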
    def get_useful_memory(self, goal: str, sub_goal: str, topn=3) -> str:
        # self.callback("get_useful_memory", {"topn": 3}, "...")
        mems = self._canvas.get_memory()
        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
        try:
            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
            mems = [mems[r] for r in rank]
            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a, _ in mems])
        except Exception as e:
            logging.exception(e)
            return "Error occurred."
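

# A minimal sketch of the configuration shape this component expects, inferred
# from __init__ and _load_tool_obj above; the names and values here are
# illustrative assumptions, not taken from a real canvas:
#
#   param = AgentParam()
#   param.tools = [{
#       "component_name": "Retrieval",      # class prefix; "RetrievalParam" must exist
#       "name": "search kb",                # spaces become "_" in the tool id
#       "params": {...},                    # fields validated by param.check()
#   }]
#   param.mcp = [{
#       "mcp_id": "<server id>",            # looked up via MCPServerService
#       "tools": {"<tool name>": <metadata>},
#   }]
#   agent = Agent(canvas, "Agent:0", param)  # canvas is supplied by the runtime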