#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any

import json_repair

from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM


class AgentParam(LLMParam, ToolParamBase):
    """
    Define the Agent component parameters.

    On top of the inherited LLM/tool parameters this sets the tool metadata
    (`meta`) that describes how the agent is called as a tool, plus the
    agent-specific defaults: `tools`, `mcp`, `max_rounds` and `description`.
    """

    def __init__(self):
        # Tool-call schema for this agent: three string arguments
        # (`user_prompt`, `reasoning`, `context`), all marked as required.
        self.meta: ToolMeta = {
            "name": "agent",
            "description": "This is an agent for a specific task.",
            "parameters": {
                "user_prompt": {
                    "type": "string",
                    "description": "This is the order you need to send to the agent.",
                    "default": "",
                    "required": True
                },
                "reasoning": {
                    "type": "string",
                    "description": (
                        "Supervisor's reasoning for choosing the this agent. "
                        "Explain why this agent is being invoked and what is expected of it."
                    ),
                    "required": True
                },
                "context": {
                    "type": "string",
                    "description": (
                        "All relevant background information, prior facts, decisions, "
                        "and state needed by the agent to solve the current query. "
                        "Should be as detailed and self-contained as possible."
                    ),
                    "required": True
                },
            }
        }
        super().__init__()
        self.function_name = "agent"
        self.tools = []       # tool component configs; each is loaded by Agent._load_tool_obj
        self.mcp = []         # MCP server configs; Agent reads "mcp_id" and "tools" from each
        self.max_rounds = 5   # bound on the tool-calling loop (the loop runs max_rounds + 1 times)
        self.description = ""


class Agent(LLM, ToolBase):
    """
    LLM component that can call tools in a ReAct-style loop.

    Tools come from two sources: tool components configured in
    `param.tools`, and tools exposed by the MCP servers listed in
    `param.mcp`. Both are registered in `self.tools` by name and described
    to the model through `self.tool_meta`.
    """
    component_name = "Agent"

    def __init__(self, canvas, id, param: LLMParam):
        """
        Build the tool registry, the chat model and the tool-call session.

        Args:
            canvas: Owning canvas; used for the tenant id and the tool-use callback.
            id: Component id of this agent on the canvas.
            param: Component parameters (read through `self._param`).
        """
        LLM.__init__(self, canvas, id, param)
        self.tools = {}
        for cpn in self._param.tools:
            cpn = self._load_tool_obj(cpn)
            self.tools[cpn.get_meta()["function"]["name"]] = cpn

        self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                  max_retries=self._param.max_retries,
                                  retry_interval=self._param.delay_after_error,
                                  max_rounds=self._param.max_rounds,
                                  verbose_tool_use=True
                                  )
        self.tool_meta = [tool.get_meta() for tool in self.tools.values()]

        for mcp in self._param.mcp:
            _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            # Every tool of one MCP server shares that server's call session.
            for tnm, meta in mcp["tools"].items():
                self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
                self.tools[tnm] = tool_call_session
        self.callback = partial(self._canvas.tool_use_callback, id)
        self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
        #self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)

    def _load_tool_obj(self, cpn: dict) -> object:
        """
        Instantiate a tool component from its configuration dict.

        Args:
            cpn: Config with "component_name", "params" and optionally "name".

        Returns:
            The component instance, with id "<agent id>--><name>" where spaces
            in the name are replaced by underscores.

        Raises:
            Exception: Whatever `param.check()` raises; the message is also
                stored in the "_ERROR" output before re-raising.
        """
        # Imported here rather than at module level, as in the original code.
        from agent.component import component_class
        param = component_class(cpn["component_name"] + "Param")()
        param.update(cpn["params"])
        try:
            param.check()
        except Exception as e:
            self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
            raise
        cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
        return component_class(cpn["component_name"])(self._canvas, cpn_id, param)

    def get_meta(self) -> dict[str, Any]:
        """
        Return the tool metadata for this agent.

        The function name is the part of the component id after the last
        "-->". If `user_prompt` is set on the parameters it replaces the
        "user_prompt" property in the returned metadata.
        """
        self._param.function_name = self._id.split("-->")[-1]
        m = super().get_meta()
        if hasattr(self._param, "user_prompt") and self._param.user_prompt:
            m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
        return m

    def get_input_form(self) -> dict[str, dict]:
        """
        Return the input form: one "line" field per input element.
        """
        res = {}
        for k, v in self.get_input_elements().items():
            res[k] = {
                "type": "line",
                "name": v["name"]
            }
        # NOTE(review): __init__ indexes the entries of self._param.tools like
        # dicts (cpn["component_name"]), so this isinstance check looks like it
        # can never pass. Confirm whether self.tools.values() was intended.
        for cpn in self._param.tools:
            if not isinstance(cpn, LLM):
                continue
            res.update(cpn.get_input_form())
        return res

    # os.environ values are strings while the fallback is an int; coerce so the
    # decorator receives an int in both cases.
    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
    def _invoke(self, **kwargs):
        """
        Run the agent once.

        If a `user_prompt` kwarg is given, a user message is built from the
        optional `reasoning` and `context` kwargs plus the prompt, and it
        replaces `self._param.prompts`. Without tools this delegates to
        `LLM._invoke`.

        When a downstream component is a "message" component, no output
        structure is set and the exception handler has no "goto", the
        "content" output is set to a streaming callable and None is returned.
        Otherwise the tool loop runs to completion.

        Returns:
            The full answer text, or None on the streaming path or on error.
        """
        if kwargs.get("user_prompt"):
            usr_pmt = ""
            if kwargs.get("reasoning"):
                usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
            if kwargs.get("context"):
                usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
            if usr_pmt:
                usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
            else:
                usr_pmt = str(kwargs["user_prompt"])
            self._param.prompts = [{"role": "user", "content": usr_pmt}]

        if not self.tools:
            return LLM._invoke(self, **kwargs)

        prompt, msg = self._prepare_prompt_variables()

        component = self._canvas.get_component(self._id)
        downstreams = component["downstream"] if component else []
        ex = self.exception_handler()
        feeds_message = any(
            self._canvas.get_component_obj(cid).component_name.lower() == "message"
            for cid in downstreams
        )
        if feeds_message and not self._param.output_structure and not (ex and ex["goto"]):
            self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
            return

        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        use_tools = []
        ans = ""
        for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools):
            ans += delta_ans

        if "**ERROR**" in ans:
            logging.error(f"Agent._chat got error. response: {ans}")
            if self.get_exception_default_value():
                self.set_output("content", self.get_exception_default_value())
            else:
                self.set_output("_ERROR", ans)
            return

        self.set_output("content", ans)
        if use_tools:
            self.set_output("use_tools", use_tools)
        return ans

    def stream_output_with_tools(self, prompt, msg):
        """
        Streaming counterpart of `_invoke`: yield answer deltas as they arrive.

        On a delta containing "**ERROR**" the stream stops. If an exception
        default value is configured it is stored in "content" and yielded;
        otherwise the delta is stored in "_ERROR". On normal completion the
        accumulated text is stored in "content" and, if any tool was used,
        the tool records in "use_tools".
        """
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
        answer_without_toolcall = ""
        use_tools = []
        for delta_ans, _ in self._react_with_tools_streamly(prompt, msg, use_tools):
            if "**ERROR**" in delta_ans:
                if self.get_exception_default_value():
                    self.set_output("content", self.get_exception_default_value())
                    yield self.get_exception_default_value()
                else:
                    self.set_output("_ERROR", delta_ans)
                # Stop here. Falling through would stream the raw error text
                # and overwrite the "content" fallback that was just set.
                return
            answer_without_toolcall += delta_ans
            yield delta_ans

        self.set_output("content", answer_without_toolcall)
        if use_tools:
            self.set_output("use_tools", use_tools)

    def _gen_citations(self, text):
        """
        Stream a rewrite of `text` produced with the citation prompt.

        The canvas references (chunks and document aggregates) are formatted
        with `kb_prompt` and passed as the system prompt; `text` is the user
        message.
        """
        retrievals = self._canvas.get_reference()
        retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
        formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
        for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
                                                  {"role": "user", "content": text}
                                                  ]):
            yield delta_ans

    def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
        """
        Run the tool-calling loop and stream the final answer.

        Each round asks the model for the next step, parses the response as a
        JSON list of `{"name", "arguments"}` objects, runs the named tools in
        a thread pool and appends a reflection on the results to the history.
        The loop ends when the model names `COMPLETE_TASK` or after
        `max_rounds + 1` rounds, in which case a final instruction is
        appended to the history before answering.

        Args:
            prompt: System prompt, passed to the task analysis.
            history: Chat messages; copied, the caller's list is not modified.
            use_tools: Output list; one record per executed tool call is appended.

        Yields:
            `(text_delta, token_count)` tuples.
        """
        token_count = 0
        tool_metas = self.tool_meta
        hist = deepcopy(history)
        last_calling = ""
        if len(hist) > 3:
            user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
            self.callback("Multi-turn conversation optimization", {}, user_request)
        else:
            user_request = history[-1]["content"]

        def use_tool(name, args):
            """Call one tool, record the call in `use_tools`, return (name, response)."""
            nonlocal last_calling
            logging.info(f"{last_calling=} == {name=}")
            # Summarize of function calling
            #if all([
            #    isinstance(self.toolcall_session.get_tool_obj(name), Agent),
            #    last_calling,
            #    last_calling != name
            #]):
            #    self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
            last_calling = name
            tool_response = self.toolcall_session.tool_call(name, args)
            use_tools.append({
                "name": name,
                "arguments": args,
                "results": tool_response
            })
            # self.callback("add_memory", {}, "...")
            #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))

            return name, tool_response

        def complete():
            """Stream the final answer, adding citations when references exist."""
            # Citations are only produced when the canvas holds reference
            # chunks and this agent's id contains no "-->".
            need2cite = self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
            cited = False
            if hist[0]["role"] == "system" and need2cite:
                if len(hist) < 7:
                    # Short history: put the citation prompt in the system
                    # message so the answer can be streamed directly.
                    hist[0]["content"] += citation_prompt()
                    cited = True
            yield "", token_count

            _hist = hist
            if len(hist) > 12:
                # Keep the first two messages and the ten most recent ones.
                _hist = [hist[0], hist[1], *hist[-10:]]
            entire_txt = ""
            for delta_ans in self._generate_streamly(_hist):
                if not need2cite or cited:
                    yield delta_ans, 0
                entire_txt += delta_ans
            if not need2cite or cited:
                return

            # The answer was held back above; stream the citation pass instead.
            txt = ""
            for delta_ans in self._gen_citations(entire_txt):
                yield delta_ans, 0
                txt += delta_ans

            self.callback("gen_citations", {}, txt)

        def append_user_content(hist, content):
            """Append `content` to the trailing user message, or add a new one."""
            if hist[-1]["role"] == "user":
                hist[-1]["content"] += content
            else:
                hist.append({"role": "user", "content": content})

        task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
        self.callback("analyze_task", {}, task_desc)
        for _ in range(self._param.max_rounds + 1):
            response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
            # self.callback("next_step", {}, str(response)[:256]+"...")
            token_count += tk
            hist.append({"role": "assistant", "content": response})
            try:
                functions = json_repair.loads(re.sub(r"```.*", "", response))
                if not isinstance(functions, list):
                    raise TypeError(f"List should be returned, but `{functions}`")
                for f in functions:
                    if not isinstance(f, dict):
                        raise TypeError(f"An object type should be returned, but `{f}`")
                with ThreadPoolExecutor(max_workers=5) as executor:
                    thr = []
                    for func in functions:
                        name = func["name"]
                        args = func["arguments"]
                        if name == COMPLETE_TASK:
                            append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
                            for txt, tkcnt in complete():
                                yield txt, tkcnt
                            return

                        thr.append(executor.submit(use_tool, name, args))

                    reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
                    append_user_content(hist, reflection)
                    self.callback("reflection", {}, str(reflection))

            except Exception as e:
                # Report the failure back to the model as a user message so the
                # next round can correct its response.
                logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
                err_msg = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
                append_user_content(hist, err_msg)

        logging.warning(f"Exceed max rounds: {self._param.max_rounds}")
        final_instruction = f"""
{user_request}
IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
Instructions:
1. SYNTHESIZE all information collected during this conversation
2. Provide a COMPLETE response using existing data - do not suggest additional research
3. Structure your response as a FINAL DELIVERABLE, not a plan
4. If information is incomplete, state what you found and provide the best analysis possible with available data
5. DO NOT mention conversation limits or suggest further steps
6. Focus on delivering VALUE with the information already gathered
Respond immediately with your final comprehensive answer.
        """
        append_user_content(hist, final_instruction)

        for txt, tkcnt in complete():
            yield txt, tkcnt

    def get_useful_memory(self, goal: str, sub_goal: str, topn=3) -> str:
        """
        Return the `topn` canvas memories ranked most useful for the goal.

        Memories are `(user, assistant, summary)` tuples from the canvas. The
        summaries are ranked by `rank_memories`, whose response is parsed as
        a JSON list of indices into the memory list.

        Returns:
            The selected memories formatted as "User: ...\\nAgent: ..." blocks
            joined by blank lines, or "Error occurred." if parsing or indexing
            fails.
        """
        # self.callback("get_useful_memory", {"topn": 3}, "...")
        mems = self._canvas.get_memory()
        rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
        try:
            rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
            mems = [mems[r] for r in rank]
            return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a, _ in mems])
        except Exception as e:
            logging.exception(e)

        return "Error occurred."