
fc_agent_runner.py

import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)
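

# FunctionCallAgentRunner drives a function-calling agent loop: it repeatedly
# invokes the LLM with the available tools, executes any tool calls the model
# emits, feeds the observations back into the prompt, and stops once the model
# answers without requesting a tool (or the iteration budget is exhausted).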
class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        # propagate metadata filtering conditions to dataset retriever tools
        # (works around metadata filters not being applied during retrieval)
        if app_config.dataset is not None:
            metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
            for dataset_retriever_tool in tool_instances.values():
                if hasattr(dataset_retriever_tool, "retrieval_tool"):
                    dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions

        assert app_config.agent
        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # keep running until the model stops emitting tool calls
        function_call_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.total_tokens += usage.total_tokens  # keep the total in sync with the per-side counters
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price
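
        # Usage is accumulated in a mutable dict rather than a bare local so the
        # nested helper can update it without a `nonlocal` declaration; the final
        # totals are read back out of llm_usage["usage"] after the loop ends.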
        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids: list[str] = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None
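
            # invoke_llm returns either a chunk generator (streaming) or a full
            # LLMResult (blocking), depending on stream_tool_call; both branches
            # below extract tool calls, text content, and token usage.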
            if isinstance(chunks, Generator):
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk) or [])
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except json.JSONDecodeError:
                            # ensure ascii to avoid encoding error
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += str(chunk.delta.message.content)

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except json.JSONDecodeError:
                        # ensure ascii to avoid encoding error
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += str(result.message.content)

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )
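
            # Record the assistant turn in the scratchpad: either the tool calls
            # the model requested, or its plain-text answer. These thoughts are
            # replayed into the prompt on the next iteration.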
            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"
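
            # Dispatch each requested tool call. An unknown tool name produces an
            # error observation instead of raising, so the model can see the
            # failure and recover on the next iteration.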
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                        app_id=self.application_generate_entity.app_config.app_id,
                        message_id=self.message.id,
                        conversation_id=self.conversation.id,
                    )
                    # publish files
                    for message_file_id in message_files:
                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=str(tool_response["tool_response"]),
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )
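
            # Persist the tool observations on the same agent thought and notify
            # the client so the UI can render tool results as they arrive.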
            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name="",
                    tool_input="",
                    thought="",
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer="",
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )
        return tool_calls
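
    # Both extract helpers return tuples of (tool_call_id, tool_call_name,
    # tool_call_args); e.g., for a hypothetical weather tool the result might
    # look like [("call_abc123", "get_weather", {"city": "Berlin"})].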
    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )
        return tool_calls

    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        For now, GPT models only need the combination of function calling and vision
        at the first iteration, so image and file contents are replaced with text
        placeholders in the prompt messages of subsequent iterations.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # strip image content from the prompt after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages