
base_agent_runner.py

import json
import logging
import uuid
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMUsage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
    ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)

class BaseAgentRunner(AppRunner):
    def __init__(
        self,
        *,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        model_instance: ModelInstance,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
    ) -> None:
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
            user_id=user_id,
            inputs=cast(dict, application_generate_entity.inputs),
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []

        self.query: Optional[str] = ""
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        Convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        assert tool_entity.entity.description

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_merged_runtime_parameters()
        for parameter in parameters:
            # only expose parameters the LLM is expected to fill in
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            # file parameters are not included in the LLM-facing schema
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue

            # SELECT parameters carry an enum of allowed option values
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        Convert dataset retriever tool to prompt message tool
        """
        assert tool.entity.description

        prompt_tool = PromptMessageTool(
            name=tool.entity.identity.name,
            description=tool.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in (self.app_config.agent.tools or []) if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.entity.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        Update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters()

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue

            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str | None,
        tool_input: Union[str, dict, None],
        thought: str | None,
        observation: Union[str, dict, None],
        tool_invoke_meta: Union[str, dict, None],
        answer: str | None,
        messages_ids: list[str],
        llm_usage: LLMUsage | None = None,
    ):
        """
        Save agent thought
        """
        updated_agent_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if not updated_agent_thought:
            raise ValueError("agent thought not found")
        # re-fetch the row in the current session; both names below refer to the same ORM object
        agent_thought = updated_agent_thought

        if thought:
            agent_thought.thought += thought

        if tool_name:
            agent_thought.tool = tool_name

        if tool_input:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            updated_agent_thought.tool_input = tool_input

        if observation:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            updated_agent_thought.observation = observation

        if answer:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            updated_agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            updated_agent_thought.message_token = llm_usage.prompt_tokens
            updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
            updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
            updated_agent_thought.answer_token = llm_usage.completion_tokens
            updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
            updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
            updated_agent_thought.tokens = llm_usage.total_tokens
            updated_agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = updated_agent_thought.tool_labels or {}
        tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        updated_agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            updated_agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result: list[PromptMessage] = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        """
        Rebuild the user prompt for a historical message, attaching its uploaded
        files as prompt contents when a file upload config is available.
        """
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if not files:
            return UserPromptMessage(content=message.query)
        if message.app_model_config:
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        else:
            file_extra_config = None

        if not file_extra_config:
            return UserPromptMessage(content=message.query)

        image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
        image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

        file_objs = file_factory.build_from_message_files(
            message_files=files, tenant_id=self.tenant_id, config=file_extra_config
        )
        if not file_objs:
            return UserPromptMessage(content=message.query)
        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
        prompt_message_contents.append(TextPromptMessageContent(data=message.query))
        for file in file_objs:
            prompt_message_contents.append(
                file_manager.to_prompt_message_content(
                    file,
                    image_detail_config=image_detail_config,
                )
            )
        return UserPromptMessage(content=prompt_message_contents)
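

# --- Illustration (not part of the runner) ---------------------------------
# A minimal sketch of the OpenAI-style JSON-schema dict that
# _convert_tool_to_prompt_message_tool() and update_prompt_message_tool()
# assemble in PromptMessageTool.parameters. The parameter names and values
# below are hypothetical; only the overall shape mirrors the code above
# (LLM-form parameters become "properties", SELECT options become "enum",
# required parameters are listed in "required").
if __name__ == "__main__":
    example_parameters = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query the LLM should fill in.",
            },
            "language": {
                "type": "string",
                "description": "Result language.",
                "enum": ["en", "zh"],  # only present for SELECT-type parameters
            },
        },
        "required": ["query"],
    }
    print(json.dumps(example_parameters, indent=2))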