
fc_agent_runner.py 19KB

import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)

class FunctionCallAgentRunner(BaseAgentRunner):
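    """
    Agent runner that drives tool use through the model's native function-calling interface.
    """
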
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()
        assert app_config.agent
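
        # cap the loop at max_iteration rounds (at most 99), plus one extra round
        # in which all tools are removed so the model must produce a final answer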
        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1

        # continue to run until there are no more tool calls
        function_call_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""
        prompt_messages: list = []  # Initialize prompt_messages

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager
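
        # accumulate per-call token and price usage into a single shared LLMUsage record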
        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.total_tokens += usage.total_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids: list[str] = []
            agent_thought_id = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None
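
            # depending on stream_tool_call, invoke_llm returns either a generator of
            # chunks (streaming) or a single LLMResult (blocking); handle both paths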
            if isinstance(chunks, Generator):
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk) or [])
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except TypeError:
                            # fallback: retry with default ASCII escaping
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
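
                    # accumulate the streamed text content of this chunk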
                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += str(chunk.delta.message.content)

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except TypeError:
                        # fallback: retry with default ASCII escaping
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += str(result.message.content)

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )
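
            # record this assistant turn: either the collected tool calls or the plain text response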
            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought_id=agent_thought_id,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                        app_id=self.application_generate_entity.app_config.app_id,
                        message_id=self.message.id,
                        conversation_id=self.conversation.id,
                    )
                    # publish files
                    for message_file_id in message_files:
                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=str(tool_response["tool_response"]),
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought_id=agent_thought_id,
                    tool_name="",
                    tool_input="",
                    thought="",
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer="",
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        For now, GPT models support both function calling and vision only at the first iteration,
        so image messages are replaced with text placeholders after the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
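        """
        Assemble the full prompt: system message, trimmed history, current user query,
        and the thoughts (assistant turns and tool results) of the current run.
        """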
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages