You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_agent_runner.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. import json
  2. import logging
  3. import uuid
  4. from typing import Optional, Union, cast
  5. from core.agent.entities import AgentEntity, AgentToolEntity
  6. from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
  7. from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
  8. from core.app.apps.base_app_queue_manager import AppQueueManager
  9. from core.app.apps.base_app_runner import AppRunner
  10. from core.app.entities.app_invoke_entities import (
  11. AgentChatAppGenerateEntity,
  12. ModelConfigWithCredentialsEntity,
  13. )
  14. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  15. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  16. from core.file import file_manager
  17. from core.memory.token_buffer_memory import TokenBufferMemory
  18. from core.model_manager import ModelInstance
  19. from core.model_runtime.entities import (
  20. AssistantPromptMessage,
  21. LLMUsage,
  22. PromptMessage,
  23. PromptMessageTool,
  24. SystemPromptMessage,
  25. TextPromptMessageContent,
  26. ToolPromptMessage,
  27. UserPromptMessage,
  28. )
  29. from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
  30. from core.model_runtime.entities.model_entities import ModelFeature
  31. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  32. from core.prompt.utils.extract_thread_messages import extract_thread_messages
  33. from core.tools.__base.tool import Tool
  34. from core.tools.entities.tool_entities import (
  35. ToolParameter,
  36. )
  37. from core.tools.tool_manager import ToolManager
  38. from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
  39. from extensions.ext_database import db
  40. from factories import file_factory
  41. from models.model import Conversation, Message, MessageAgentThought, MessageFile
  42. logger = logging.getLogger(__name__)
  43. class BaseAgentRunner(AppRunner):
  44. def __init__(
  45. self,
  46. *,
  47. tenant_id: str,
  48. application_generate_entity: AgentChatAppGenerateEntity,
  49. conversation: Conversation,
  50. app_config: AgentChatAppConfig,
  51. model_config: ModelConfigWithCredentialsEntity,
  52. config: AgentEntity,
  53. queue_manager: AppQueueManager,
  54. message: Message,
  55. user_id: str,
  56. model_instance: ModelInstance,
  57. memory: Optional[TokenBufferMemory] = None,
  58. prompt_messages: Optional[list[PromptMessage]] = None,
  59. ) -> None:
  60. self.tenant_id = tenant_id
  61. self.application_generate_entity = application_generate_entity
  62. self.conversation = conversation
  63. self.app_config = app_config
  64. self.model_config = model_config
  65. self.config = config
  66. self.queue_manager = queue_manager
  67. self.message = message
  68. self.user_id = user_id
  69. self.memory = memory
  70. self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
  71. self.model_instance = model_instance
  72. # init callback
  73. self.agent_callback = DifyAgentCallbackHandler()
  74. # init dataset tools
  75. hit_callback = DatasetIndexToolCallbackHandler(
  76. queue_manager=queue_manager,
  77. app_id=self.app_config.app_id,
  78. message_id=message.id,
  79. user_id=user_id,
  80. invoke_from=self.application_generate_entity.invoke_from,
  81. )
  82. self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
  83. tenant_id=tenant_id,
  84. dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
  85. retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
  86. return_resource=app_config.additional_features.show_retrieve_source,
  87. invoke_from=application_generate_entity.invoke_from,
  88. hit_callback=hit_callback,
  89. user_id=user_id,
  90. inputs=cast(dict, application_generate_entity.inputs),
  91. )
  92. # get how many agent thoughts have been created
  93. self.agent_thought_count = (
  94. db.session.query(MessageAgentThought)
  95. .filter(
  96. MessageAgentThought.message_id == self.message.id,
  97. )
  98. .count()
  99. )
  100. db.session.close()
  101. # check if model supports stream tool call
  102. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  103. model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  104. features = model_schema.features if model_schema and model_schema.features else []
  105. self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
  106. self.files = application_generate_entity.files if ModelFeature.VISION in features else []
  107. self.query: Optional[str] = ""
  108. self._current_thoughts: list[PromptMessage] = []
  109. def _repack_app_generate_entity(
  110. self, app_generate_entity: AgentChatAppGenerateEntity
  111. ) -> AgentChatAppGenerateEntity:
  112. """
  113. Repack app generate entity
  114. """
  115. if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
  116. app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
  117. return app_generate_entity
  118. def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
  119. """
  120. convert tool to prompt message tool
  121. """
  122. tool_entity = ToolManager.get_agent_tool_runtime(
  123. tenant_id=self.tenant_id,
  124. app_id=self.app_config.app_id,
  125. agent_tool=tool,
  126. invoke_from=self.application_generate_entity.invoke_from,
  127. )
  128. assert tool_entity.entity.description
  129. message_tool = PromptMessageTool(
  130. name=tool.tool_name,
  131. description=tool_entity.entity.description.llm,
  132. parameters={
  133. "type": "object",
  134. "properties": {},
  135. "required": [],
  136. },
  137. )
  138. parameters = tool_entity.get_merged_runtime_parameters()
  139. for parameter in parameters:
  140. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  141. continue
  142. parameter_type = parameter.type.as_normal_type()
  143. if parameter.type in {
  144. ToolParameter.ToolParameterType.SYSTEM_FILES,
  145. ToolParameter.ToolParameterType.FILE,
  146. ToolParameter.ToolParameterType.FILES,
  147. }:
  148. continue
  149. enum = []
  150. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  151. enum = [option.value for option in parameter.options] if parameter.options else []
  152. message_tool.parameters["properties"][parameter.name] = (
  153. {
  154. "type": parameter_type,
  155. "description": parameter.llm_description or "",
  156. }
  157. if parameter.input_schema is None
  158. else parameter.input_schema
  159. )
  160. if len(enum) > 0:
  161. message_tool.parameters["properties"][parameter.name]["enum"] = enum
  162. if parameter.required:
  163. message_tool.parameters["required"].append(parameter.name)
  164. return message_tool, tool_entity
  165. def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
  166. """
  167. convert dataset retriever tool to prompt message tool
  168. """
  169. assert tool.entity.description
  170. prompt_tool = PromptMessageTool(
  171. name=tool.entity.identity.name,
  172. description=tool.entity.description.llm,
  173. parameters={
  174. "type": "object",
  175. "properties": {},
  176. "required": [],
  177. },
  178. )
  179. for parameter in tool.get_runtime_parameters():
  180. parameter_type = "string"
  181. prompt_tool.parameters["properties"][parameter.name] = {
  182. "type": parameter_type,
  183. "description": parameter.llm_description or "",
  184. }
  185. if parameter.required:
  186. if parameter.name not in prompt_tool.parameters["required"]:
  187. prompt_tool.parameters["required"].append(parameter.name)
  188. return prompt_tool
  189. def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
  190. """
  191. Init tools
  192. """
  193. tool_instances = {}
  194. prompt_messages_tools = []
  195. for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
  196. try:
  197. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  198. except Exception:
  199. # api tool may be deleted
  200. continue
  201. # save tool entity
  202. tool_instances[tool.tool_name] = tool_entity
  203. # save prompt tool
  204. prompt_messages_tools.append(prompt_tool)
  205. # convert dataset tools into ModelRuntime Tool format
  206. for dataset_tool in self.dataset_tools:
  207. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  208. # save prompt tool
  209. prompt_messages_tools.append(prompt_tool)
  210. # save tool entity
  211. tool_instances[dataset_tool.entity.identity.name] = dataset_tool
  212. return tool_instances, prompt_messages_tools
  213. def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
  214. """
  215. update prompt message tool
  216. """
  217. # try to get tool runtime parameters
  218. tool_runtime_parameters = tool.get_runtime_parameters()
  219. for parameter in tool_runtime_parameters:
  220. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  221. continue
  222. parameter_type = parameter.type.as_normal_type()
  223. if parameter.type in {
  224. ToolParameter.ToolParameterType.SYSTEM_FILES,
  225. ToolParameter.ToolParameterType.FILE,
  226. ToolParameter.ToolParameterType.FILES,
  227. }:
  228. continue
  229. enum = []
  230. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  231. enum = [option.value for option in parameter.options] if parameter.options else []
  232. prompt_tool.parameters["properties"][parameter.name] = (
  233. {
  234. "type": parameter_type,
  235. "description": parameter.llm_description or "",
  236. }
  237. if parameter.input_schema is None
  238. else parameter.input_schema
  239. )
  240. if len(enum) > 0:
  241. prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
  242. if parameter.required:
  243. if parameter.name not in prompt_tool.parameters["required"]:
  244. prompt_tool.parameters["required"].append(parameter.name)
  245. return prompt_tool
  246. def create_agent_thought(
  247. self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
  248. ) -> MessageAgentThought:
  249. """
  250. Create agent thought
  251. """
  252. thought = MessageAgentThought(
  253. message_id=message_id,
  254. message_chain_id=None,
  255. thought="",
  256. tool=tool_name,
  257. tool_labels_str="{}",
  258. tool_meta_str="{}",
  259. tool_input=tool_input,
  260. message=message,
  261. message_token=0,
  262. message_unit_price=0,
  263. message_price_unit=0,
  264. message_files=json.dumps(messages_ids) if messages_ids else "",
  265. answer="",
  266. observation="",
  267. answer_token=0,
  268. answer_unit_price=0,
  269. answer_price_unit=0,
  270. tokens=0,
  271. total_price=0,
  272. position=self.agent_thought_count + 1,
  273. currency="USD",
  274. latency=0,
  275. created_by_role="account",
  276. created_by=self.user_id,
  277. )
  278. db.session.add(thought)
  279. db.session.commit()
  280. db.session.refresh(thought)
  281. db.session.close()
  282. self.agent_thought_count += 1
  283. return thought
  284. def save_agent_thought(
  285. self,
  286. agent_thought: MessageAgentThought,
  287. tool_name: str | None,
  288. tool_input: Union[str, dict, None],
  289. thought: str | None,
  290. observation: Union[str, dict, None],
  291. tool_invoke_meta: Union[str, dict, None],
  292. answer: str | None,
  293. messages_ids: list[str],
  294. llm_usage: LLMUsage | None = None,
  295. ):
  296. """
  297. Save agent thought
  298. """
  299. updated_agent_thought = (
  300. db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
  301. )
  302. if not updated_agent_thought:
  303. raise ValueError("agent thought not found")
  304. agent_thought = updated_agent_thought
  305. if thought:
  306. agent_thought.thought += thought
  307. if tool_name:
  308. agent_thought.tool = tool_name
  309. if tool_input:
  310. if isinstance(tool_input, dict):
  311. try:
  312. tool_input = json.dumps(tool_input, ensure_ascii=False)
  313. except Exception:
  314. tool_input = json.dumps(tool_input)
  315. updated_agent_thought.tool_input = tool_input
  316. if observation:
  317. if isinstance(observation, dict):
  318. try:
  319. observation = json.dumps(observation, ensure_ascii=False)
  320. except Exception:
  321. observation = json.dumps(observation)
  322. updated_agent_thought.observation = observation
  323. if answer:
  324. agent_thought.answer = answer
  325. if messages_ids is not None and len(messages_ids) > 0:
  326. updated_agent_thought.message_files = json.dumps(messages_ids)
  327. if llm_usage:
  328. updated_agent_thought.message_token = llm_usage.prompt_tokens
  329. updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
  330. updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
  331. updated_agent_thought.answer_token = llm_usage.completion_tokens
  332. updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
  333. updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
  334. updated_agent_thought.tokens = llm_usage.total_tokens
  335. updated_agent_thought.total_price = llm_usage.total_price
  336. # check if tool labels is not empty
  337. labels = updated_agent_thought.tool_labels or {}
  338. tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
  339. for tool in tools:
  340. if not tool:
  341. continue
  342. if tool not in labels:
  343. tool_label = ToolManager.get_tool_label(tool)
  344. if tool_label:
  345. labels[tool] = tool_label.to_dict()
  346. else:
  347. labels[tool] = {"en_US": tool, "zh_Hans": tool}
  348. updated_agent_thought.tool_labels_str = json.dumps(labels)
  349. if tool_invoke_meta is not None:
  350. if isinstance(tool_invoke_meta, dict):
  351. try:
  352. tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
  353. except Exception:
  354. tool_invoke_meta = json.dumps(tool_invoke_meta)
  355. updated_agent_thought.tool_meta_str = tool_invoke_meta
  356. db.session.commit()
  357. db.session.close()
  358. def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
  359. """
  360. Organize agent history
  361. """
  362. result: list[PromptMessage] = []
  363. # check if there is a system message in the beginning of the conversation
  364. for prompt_message in prompt_messages:
  365. if isinstance(prompt_message, SystemPromptMessage):
  366. result.append(prompt_message)
  367. messages: list[Message] = (
  368. db.session.query(Message)
  369. .filter(
  370. Message.conversation_id == self.message.conversation_id,
  371. )
  372. .order_by(Message.created_at.desc())
  373. .all()
  374. )
  375. messages = list(reversed(extract_thread_messages(messages)))
  376. for message in messages:
  377. if message.id == self.message.id:
  378. continue
  379. result.append(self.organize_agent_user_prompt(message))
  380. agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
  381. if agent_thoughts:
  382. for agent_thought in agent_thoughts:
  383. tools = agent_thought.tool
  384. if tools:
  385. tools = tools.split(";")
  386. tool_calls: list[AssistantPromptMessage.ToolCall] = []
  387. tool_call_response: list[ToolPromptMessage] = []
  388. try:
  389. tool_inputs = json.loads(agent_thought.tool_input)
  390. except Exception:
  391. tool_inputs = {tool: {} for tool in tools}
  392. try:
  393. tool_responses = json.loads(agent_thought.observation)
  394. except Exception:
  395. tool_responses = dict.fromkeys(tools, agent_thought.observation)
  396. for tool in tools:
  397. # generate a uuid for tool call
  398. tool_call_id = str(uuid.uuid4())
  399. tool_calls.append(
  400. AssistantPromptMessage.ToolCall(
  401. id=tool_call_id,
  402. type="function",
  403. function=AssistantPromptMessage.ToolCall.ToolCallFunction(
  404. name=tool,
  405. arguments=json.dumps(tool_inputs.get(tool, {})),
  406. ),
  407. )
  408. )
  409. tool_call_response.append(
  410. ToolPromptMessage(
  411. content=tool_responses.get(tool, agent_thought.observation),
  412. name=tool,
  413. tool_call_id=tool_call_id,
  414. )
  415. )
  416. result.extend(
  417. [
  418. AssistantPromptMessage(
  419. content=agent_thought.thought,
  420. tool_calls=tool_calls,
  421. ),
  422. *tool_call_response,
  423. ]
  424. )
  425. if not tools:
  426. result.append(AssistantPromptMessage(content=agent_thought.thought))
  427. else:
  428. if message.answer:
  429. result.append(AssistantPromptMessage(content=message.answer))
  430. db.session.close()
  431. return result
  432. def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
  433. files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
  434. if not files:
  435. return UserPromptMessage(content=message.query)
  436. if message.app_model_config:
  437. file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
  438. else:
  439. file_extra_config = None
  440. if not file_extra_config:
  441. return UserPromptMessage(content=message.query)
  442. image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
  443. image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  444. file_objs = file_factory.build_from_message_files(
  445. message_files=files, tenant_id=self.tenant_id, config=file_extra_config
  446. )
  447. if not file_objs:
  448. return UserPromptMessage(content=message.query)
  449. prompt_message_contents: list[PromptMessageContentUnionTypes] = []
  450. prompt_message_contents.append(TextPromptMessageContent(data=message.query))
  451. for file in file_objs:
  452. prompt_message_contents.append(
  453. file_manager.to_prompt_message_content(
  454. file,
  455. image_detail_config=image_detail_config,
  456. )
  457. )
  458. return UserPromptMessage(content=prompt_message_contents)