
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any

import json_repair
from timeit import default_timer as timer

from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM


class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.
    """

    def __init__(self):
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {
                    "type": "string",
                    "description": "This is the order you need to send to the agent.",
                    "default": "",
                    "required": True
                },
                "reasoning": {
                    "type": "string",
                    "description": (
                        "Supervisor's reasoning for choosing this agent. "
                        "Explain why this agent is being invoked and what is expected of it."
                    ),
                    "required": True
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, "
                        "and state needed by the agent to solve the current query. "
                        "Should be as detailed and self-contained as possible."
                    ),
                    "required": True
                },
            }
        }
        super().__init__()
        self.function_name = "agent"
        self.tools = []
        self.mcp = []
        self.max_rounds = 5
        self.description = ""
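
# For illustration only (not part of the original file): the `meta` above is
# surfaced through get_meta() as an OpenAI-style function tool. Assuming the
# envelope built by ToolBase, the result looks roughly like:
#
#   {
#       "type": "function",
#       "function": {
#           "name": "agent",
#           "description": "This is an agent for a specific task.",
#           "parameters": {
#               "type": "object",
#               "properties": {"user_prompt": ..., "reasoning": ..., "context": ...},
#               "required": ["user_prompt", "reasoning", "context"]
#           }
#       }
#   }
#
# The exact shape comes from ToolBase/LLM.get_meta(), so treat this as a sketch.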


class Agent(LLM, ToolBase):
    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        LLM.__init__(self, canvas, id, param)
        # Instantiate every configured tool component and index it by its
        # function name so the LLM can address it in tool calls.
        self.tools = {}
        for cpn in self._param.tools:
            cpn = self._load_tool_obj(cpn)
            self.tools[cpn.get_meta()["function"]["name"]] = cpn

        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                  max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error,
                                  max_rounds=self._param.max_rounds,
                                  verbose_tool_use=True
                                  )

        self.tool_meta = [v.get_meta() for _, v in self.tools.items()]
        # Register MCP tools: one shared call session per MCP server, with each
        # tool's metadata converted to the OpenAI tool format.
        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            for tnm, meta in mcp["tools"].items():
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
                self.tools[tnm] = tool_call_session

        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        # self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)
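
    # Design note: plugin tools and MCP tools end up in the same `self.tools`
    # mapping keyed by function name, so LLMToolPluginCallSession can dispatch a
    # model-emitted call with a plain dict lookup. MCP tools from the same server
    # share a single MCPToolCallSession instance.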

    def _load_tool_obj(self, cpn: dict) -> object:
        from agent.component import component_class
        param = component_class(cpn["component_name"] + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
            raise
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(cpn["component_name"])(self._canvas, cpn_id, param)

    def get_meta(self) -> dict[str, Any]:
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
        return m

    def get_input_form(self) -> dict[str, dict]:
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        for cpn in self._param.tools:
            if not isinstance(cpn, LLM):
                continue
            res.update(cpn.get_input_form())
        return res

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20 * 60)))
    def _invoke(self, **kwargs):
        if kwargs.get("user_prompt"):
            # Fold the supervisor's hand-off (reasoning + context + query) into a
            # single user turn.
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]

        if not self.tools:
            return LLM._invoke(self, **kwargs)

        prompt, msg = self._prepare_prompt_variables()

        downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
        ex = self.exception_handler()
        if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
            return

        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        use_tools = []
        ans = ""
        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
            ans += delta_ans

        if ans.find("**ERROR**") >= 0:
            logging.error(f"Agent._chat got error. response: {ans}")
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", ans)
            return

        self.set_output("content", ans)
        if use_tools:
            self.set_output("use_tools", use_tools)
        return ans
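
    # Design note: _invoke() has two paths. When a downstream Message component
    # will render the answer (and no structured output or exception-goto is
    # configured), it stores a generator via set_output("content", ...) so tokens
    # can stream to the UI; otherwise it drains the same generator here and stores
    # the final string. Both paths go through _react_with_tools_streamly().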

    def stream_output_with_tools(self, prompt, msg):
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer_without_toolcall = ""
        use_tools = []
        for delta_ans, _ in self._react_with_tools_streamly(prompt, msg, use_tools):
            if delta_ans.find("**ERROR**") >= 0:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", delta_ans)
            answer_without_toolcall += delta_ans
            yield delta_ans

        self.set_output("content", answer_without_toolcall)
        if use_tools:
            self.set_output("use_tools", use_tools)

    def _gen_citations(self, text):
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formatted_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formatted_refer))},
                                                  {"role": "user", "content": text}
                                                  ]):
            yield delta_ans

    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
        token_count = 0
        tool_metas = self.tool_meta
        hist = deepcopy(history)
        last_calling = ""

        # Condense a multi-turn conversation into one self-contained request.
        if len(hist) > 3:
            st = timer()
            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
            self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer() - st)
        else:
            user_request = history[-1]["content"]

        def use_tool(name, args):
            nonlocal hist, use_tools, token_count, last_calling, user_request
            logging.info(f"{last_calling=} == {name=}")
            # Summarize the function-calling history for sub-agents.
            # if all([
            #     isinstance(self.toolcall_session.get_tool_obj(name), Agent),
            #     last_calling,
            #     last_calling != name
            # ]):
            #     self.toolcall_session.get_tool_obj(name).add2system_prompt(
            #         "The chat history with other agents is as follows: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
            last_calling = name
            tool_response = self.toolcall_session.tool_call(name, args)
            use_tools.append({
                "name": name,
                "arguments": args,
                "results": tool_response
            })
            # self.callback("add_memory", {}, "...")
            # self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
            return name, tool_response
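
        # Note: use_tool() is submitted to a ThreadPoolExecutor below, so several
        # tool calls emitted in one ReAct step run concurrently; their results are
        # only read back via Future.result() once the whole batch has finished.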

        def complete():
            nonlocal hist
            need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
            cited = False
            if hist[0]["role"] == "system" and need2cite:
                if len(hist) < 7:
                    hist[0]["content"] += citation_prompt()
                    cited = True
            yield "", token_count

            _hist = hist
            if len(hist) > 12:
                # Keep the system prompt, the first user turn, and the latest ten turns.
                _hist = [hist[0], hist[1], *hist[-10:]]
            entire_txt = ""
            for delta_ans in self._generate_streamly(_hist):
                if not need2cite or cited:
                    yield delta_ans, 0
                entire_txt += delta_ans
            if not need2cite or cited:
                return

            # Citations were not inlined: re-render the draft answer with citations.
            st = timer()
            txt = ""
            for delta_ans in self._gen_citations(entire_txt):
                yield delta_ans, 0
                txt += delta_ans
            self.callback("gen_citations", {}, txt, elapsed_time=timer() - st)

        def append_user_content(hist, content):
            # Merge consecutive user turns instead of stacking them.
            if hist[-1]["role"] == "user":
                hist[-1]["content"] += content
            else:
                hist.append({"role": "user", "content": content})

        st = timer()
        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
        self.callback("analyze_task", {}, task_desc, elapsed_time=timer() - st)

        for _ in range(self._param.max_rounds + 1):
            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
            # self.callback("next_step", {}, str(response)[:256] + "...")
            token_count += tk
            hist.append({"role": "assistant", "content": response})
            try:
                functions = json_repair.loads(re.sub(r"```.*", "", response))
                if not isinstance(functions, list):
                    raise TypeError(f"A list should be returned, but got `{functions}`")
                for f in functions:
                    if not isinstance(f, dict):
                        raise TypeError(f"An object should be returned, but got `{f}`")
                with ThreadPoolExecutor(max_workers=5) as executor:
                    thr = []
                    for func in functions:
                        name = func["name"]
                        args = func["arguments"]
                        if name == COMPLETE_TASK:
                            append_user_content(hist, f"Respond with a formal answer. FORGET (DO NOT mention) `{COMPLETE_TASK}`. The response language MUST be the same as that of the first user request.\n")
                            for txt, tkcnt in complete():
                                yield txt, tkcnt
                            return
                        thr.append(executor.submit(use_tool, name, args))

                    st = timer()
                    reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
                    append_user_content(hist, reflection)
                    self.callback("reflection", {}, str(reflection), elapsed_time=timer() - st)
            except Exception as e:
                logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
                e = f"\nTool call error, please correct the input parameters according to the required format and call it again.\n *** Exception ***\n{e}"
                append_user_content(hist, str(e))

        logging.warning(f"Exceeded max rounds: {self._param.max_rounds}")
        final_instruction = f"""
{user_request}
IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
Instructions:
1. SYNTHESIZE all information collected during this conversation
2. Provide a COMPLETE response using existing data - do not suggest additional research
3. Structure your response as a FINAL DELIVERABLE, not a plan
4. If information is incomplete, state what you found and provide the best analysis possible with available data
5. DO NOT mention conversation limits or suggest further steps
6. Focus on delivering VALUE with the information already gathered
Respond immediately with your final comprehensive answer.
"""
        append_user_content(hist, final_instruction)
        for txt, tkcnt in complete():
            yield txt, tkcnt
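
    # For illustration (the tool names here are hypothetical): next_step() is
    # expected to return a JSON list that json_repair can parse, e.g.
    #   [{"name": "web_search", "arguments": {"query": "..."}},
    #    {"name": "<COMPLETE_TASK>", "arguments": {}}]
    # where an entry named COMPLETE_TASK ends the loop. Anything that is not a
    # list of objects raises and is echoed back to the model as a correction
    # prompt for the next round.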

    def get_useful_memory(self, goal: str, sub_goal: str, topn=3) -> str:
        # self.callback("get_useful_memory", {"topn": 3}, "...")
        mems = self._canvas.get_memory()
        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
        try:
            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
            mems = [mems[r] for r in rank]
            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a, _ in mems])
        except Exception as e:
            logging.exception(e)
            return "Error occurred."