
base_agent_runner.py 20KB

import json
import logging
import uuid
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMUsage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
    ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)


class BaseAgentRunner(AppRunner):
    def __init__(
        self,
        *,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        model_instance: ModelInstance,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
    ) -> None:
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.query: Optional[str] = ""
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        assert tool_entity.entity.description
        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_merged_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        assert tool.entity.description
        prompt_tool = PromptMessageTool(
            name=tool.entity.identity.name,
            description=tool.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.entity.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters()

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str | None,
        tool_input: Union[str, dict, None],
        thought: str | None,
        observation: Union[str, dict, None],
        tool_invoke_meta: Union[str, dict, None],
        answer: str | None,
        messages_ids: list[str],
        llm_usage: LLMUsage | None = None,
    ):
        """
        Save agent thought
        """
        updated_agent_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if not updated_agent_thought:
            raise ValueError("agent thought not found")
        agent_thought = updated_agent_thought

        if thought:
            agent_thought.thought += thought

        if tool_name:
            agent_thought.tool = tool_name

        if tool_input:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            updated_agent_thought.tool_input = tool_input

        if observation:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            updated_agent_thought.observation = observation

        if answer:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            updated_agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            updated_agent_thought.message_token = llm_usage.prompt_tokens
            updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
            updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
            updated_agent_thought.answer_token = llm_usage.completion_tokens
            updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
            updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
            updated_agent_thought.tokens = llm_usage.total_tokens
            updated_agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = updated_agent_thought.tool_labels or {}
        tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []

        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        updated_agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            updated_agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result: list[PromptMessage] = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if not files:
            return UserPromptMessage(content=message.query)
        if message.app_model_config:
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        else:
            file_extra_config = None

        if not file_extra_config:
            return UserPromptMessage(content=message.query)

        image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

        file_objs = file_factory.build_from_message_files(
            message_files=files, tenant_id=self.tenant_id, config=file_extra_config
        )
        if not file_objs:
            return UserPromptMessage(content=message.query)
        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
        prompt_message_contents.append(TextPromptMessageContent(data=message.query))
        for file in file_objs:
            prompt_message_contents.append(
                file_manager.to_prompt_message_content(
                    file,
                    image_detail_config=image_detail_config,
                )
            )
        return UserPromptMessage(content=prompt_message_contents)
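
For reference, below is a minimal sketch of how a concrete runner might build on this base class: initialize the tool map, record an agent thought before a tool call, and persist the result afterwards. The subclass name, the `run` signature, and the tool/answer values are illustrative assumptions and are not part of base_agent_runner.py.

# Illustrative sketch only -- not part of base_agent_runner.py.
# Names such as ExampleAgentRunner, run(), and "example_tool" are hypothetical.
class ExampleAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str) -> None:
        # build the tool runtime map and the PromptMessageTool list for the LLM
        tool_instances, prompt_tools = self._init_prompt_tools()

        # record an empty thought before invoking a (hypothetical) tool
        thought = self.create_agent_thought(
            message_id=message.id,
            message="",
            tool_name="example_tool",
            tool_input="{}",
            messages_ids=[],
        )

        # ... invoke the model and the chosen tool here ...

        # persist what the agent thought, what the tool returned, and the usage
        self.save_agent_thought(
            agent_thought=thought,
            tool_name="example_tool",
            tool_input={"query": query},
            thought="example reasoning step",
            observation="example tool output",
            tool_invoke_meta={},
            answer="example answer",
            messages_ids=[],
        )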