You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

base_agent_runner.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. import json
  2. import logging
  3. import uuid
  4. from typing import Union, cast
  5. from sqlalchemy import select
  6. from core.agent.entities import AgentEntity, AgentToolEntity
  7. from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
  8. from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
  9. from core.app.apps.base_app_queue_manager import AppQueueManager
  10. from core.app.apps.base_app_runner import AppRunner
  11. from core.app.entities.app_invoke_entities import (
  12. AgentChatAppGenerateEntity,
  13. ModelConfigWithCredentialsEntity,
  14. )
  15. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  16. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  17. from core.file import file_manager
  18. from core.memory.token_buffer_memory import TokenBufferMemory
  19. from core.model_manager import ModelInstance
  20. from core.model_runtime.entities import (
  21. AssistantPromptMessage,
  22. LLMUsage,
  23. PromptMessage,
  24. PromptMessageTool,
  25. SystemPromptMessage,
  26. TextPromptMessageContent,
  27. ToolPromptMessage,
  28. UserPromptMessage,
  29. )
  30. from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
  31. from core.model_runtime.entities.model_entities import ModelFeature
  32. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  33. from core.prompt.utils.extract_thread_messages import extract_thread_messages
  34. from core.tools.__base.tool import Tool
  35. from core.tools.entities.tool_entities import (
  36. ToolParameter,
  37. )
  38. from core.tools.tool_manager import ToolManager
  39. from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
  40. from extensions.ext_database import db
  41. from factories import file_factory
  42. from models.model import Conversation, Message, MessageAgentThought, MessageFile
# Module-level logger shared by all agent-runner code in this file.
logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
    """Base class for agent-chat app runners.

    Wires up everything a concrete agent strategy needs: conversation
    history, dataset-retriever tools, callbacks, thought bookkeeping, and
    model-capability flags (streamed tool calls, vision).
    """

    def __init__(
        self,
        *,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        model_instance: ModelInstance,
        memory: TokenBufferMemory | None = None,
        prompt_messages: list[PromptMessage] | None = None,
    ):
        """Initialize runner state and pre-compute per-run facts.

        :param tenant_id: workspace/tenant owning the app
        :param application_generate_entity: invoke-time parameters for this run
        :param conversation: conversation the current message belongs to
        :param app_config: agent-chat app configuration
        :param model_config: model config with resolved credentials
        :param config: agent entity (strategy, tools, ...)
        :param queue_manager: queue used to publish run events
        :param message: the message being answered in this run
        :param user_id: id of the end user triggering the run
        :param model_instance: model instance used for LLM calls
        :param memory: optional token-buffer memory for history
        :param prompt_messages: prompt messages assembled so far (system
            messages in here are carried into the organized history)
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        # Replay prior turns of this conversation into prompt-message form.
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
            user_id=user_id,
            inputs=cast(dict, application_generate_entity.inputs),
        )
        # get how many agent thoughts have been created for this message,
        # so new thoughts continue the position sequence
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .where(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        # Only forward user-uploaded files when the model can actually see them.
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.query: str | None = ""
        self._current_thoughts: list[PromptMessage] = []
  110. def _repack_app_generate_entity(
  111. self, app_generate_entity: AgentChatAppGenerateEntity
  112. ) -> AgentChatAppGenerateEntity:
  113. """
  114. Repack app generate entity
  115. """
  116. if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
  117. app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
  118. return app_generate_entity
    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        Convert an agent tool config into (LLM-facing schema, tool runtime).

        Resolves the tool runtime via ToolManager, then builds a JSON-schema
        style PromptMessageTool from the runtime parameters that the LLM is
        allowed to fill in. File-typed parameters are skipped: files travel
        through the file pipeline, not tool-call arguments.
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        # A description is required so the model knows what the tool does.
        assert tool_entity.entity.description
        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )
        parameters = tool_entity.get_merged_runtime_parameters()
        for parameter in parameters:
            # Only advertise parameters meant to be filled by the LLM.
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue
            parameter_type = parameter.type.as_normal_type()
            # File parameters are not passed as tool-call arguments.
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []
            # Prefer the parameter's explicitly declared input schema when set.
            message_tool.parameters["properties"][parameter.name] = (
                {
                    "type": parameter_type,
                    "description": parameter.llm_description or "",
                }
                if parameter.input_schema is None
                else parameter.input_schema
            )
            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum
            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)
        return message_tool, tool_entity
  166. def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
  167. """
  168. convert dataset retriever tool to prompt message tool
  169. """
  170. assert tool.entity.description
  171. prompt_tool = PromptMessageTool(
  172. name=tool.entity.identity.name,
  173. description=tool.entity.description.llm,
  174. parameters={
  175. "type": "object",
  176. "properties": {},
  177. "required": [],
  178. },
  179. )
  180. for parameter in tool.get_runtime_parameters():
  181. parameter_type = "string"
  182. prompt_tool.parameters["properties"][parameter.name] = {
  183. "type": parameter_type,
  184. "description": parameter.llm_description or "",
  185. }
  186. if parameter.required:
  187. if parameter.name not in prompt_tool.parameters["required"]:
  188. prompt_tool.parameters["required"].append(parameter.name)
  189. return prompt_tool
  190. def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
  191. """
  192. Init tools
  193. """
  194. tool_instances = {}
  195. prompt_messages_tools = []
  196. for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
  197. try:
  198. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  199. except Exception:
  200. # api tool may be deleted
  201. continue
  202. # save tool entity
  203. tool_instances[tool.tool_name] = tool_entity
  204. # save prompt tool
  205. prompt_messages_tools.append(prompt_tool)
  206. # convert dataset tools into ModelRuntime Tool format
  207. for dataset_tool in self.dataset_tools:
  208. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  209. # save prompt tool
  210. prompt_messages_tools.append(prompt_tool)
  211. # save tool entity
  212. tool_instances[dataset_tool.entity.identity.name] = dataset_tool
  213. return tool_instances, prompt_messages_tools
  214. def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
  215. """
  216. update prompt message tool
  217. """
  218. # try to get tool runtime parameters
  219. tool_runtime_parameters = tool.get_runtime_parameters()
  220. for parameter in tool_runtime_parameters:
  221. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  222. continue
  223. parameter_type = parameter.type.as_normal_type()
  224. if parameter.type in {
  225. ToolParameter.ToolParameterType.SYSTEM_FILES,
  226. ToolParameter.ToolParameterType.FILE,
  227. ToolParameter.ToolParameterType.FILES,
  228. }:
  229. continue
  230. enum = []
  231. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  232. enum = [option.value for option in parameter.options] if parameter.options else []
  233. prompt_tool.parameters["properties"][parameter.name] = (
  234. {
  235. "type": parameter_type,
  236. "description": parameter.llm_description or "",
  237. }
  238. if parameter.input_schema is None
  239. else parameter.input_schema
  240. )
  241. if len(enum) > 0:
  242. prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
  243. if parameter.required:
  244. if parameter.name not in prompt_tool.parameters["required"]:
  245. prompt_tool.parameters["required"].append(parameter.name)
  246. return prompt_tool
    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> str:
        """
        Create and persist a new agent-thought row for this message.

        The row starts with zeroed token/price accounting and empty
        answer/observation; save_agent_thought fills those in once the
        LLM/tool round completes.

        :param message_id: id of the chat message this thought belongs to
        :param message: prompt text recorded for this round
        :param tool_name: semicolon-joined name(s) of the tool(s) invoked
        :param tool_input: serialized tool input
        :param messages_ids: ids of message files produced in this round
        :return: id of the newly created MessageAgentThought row
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            # 1-based position within this message's thought sequence
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )
        db.session.add(thought)
        db.session.commit()
        # Read the id before closing the session (attributes expire on close).
        agent_thought_id = str(thought.id)
        self.agent_thought_count += 1
        db.session.close()
        return agent_thought_id
  285. def save_agent_thought(
  286. self,
  287. agent_thought_id: str,
  288. tool_name: str | None,
  289. tool_input: Union[str, dict, None],
  290. thought: str | None,
  291. observation: Union[str, dict, None],
  292. tool_invoke_meta: Union[str, dict, None],
  293. answer: str | None,
  294. messages_ids: list[str],
  295. llm_usage: LLMUsage | None = None,
  296. ):
  297. """
  298. Save agent thought
  299. """
  300. stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
  301. agent_thought = db.session.scalar(stmt)
  302. if not agent_thought:
  303. raise ValueError("agent thought not found")
  304. if thought:
  305. agent_thought.thought += thought
  306. if tool_name:
  307. agent_thought.tool = tool_name
  308. if tool_input:
  309. if isinstance(tool_input, dict):
  310. try:
  311. tool_input = json.dumps(tool_input, ensure_ascii=False)
  312. except Exception:
  313. tool_input = json.dumps(tool_input)
  314. agent_thought.tool_input = tool_input
  315. if observation:
  316. if isinstance(observation, dict):
  317. try:
  318. observation = json.dumps(observation, ensure_ascii=False)
  319. except Exception:
  320. observation = json.dumps(observation)
  321. agent_thought.observation = observation
  322. if answer:
  323. agent_thought.answer = answer
  324. if messages_ids is not None and len(messages_ids) > 0:
  325. agent_thought.message_files = json.dumps(messages_ids)
  326. if llm_usage:
  327. agent_thought.message_token = llm_usage.prompt_tokens
  328. agent_thought.message_price_unit = llm_usage.prompt_price_unit
  329. agent_thought.message_unit_price = llm_usage.prompt_unit_price
  330. agent_thought.answer_token = llm_usage.completion_tokens
  331. agent_thought.answer_price_unit = llm_usage.completion_price_unit
  332. agent_thought.answer_unit_price = llm_usage.completion_unit_price
  333. agent_thought.tokens = llm_usage.total_tokens
  334. agent_thought.total_price = llm_usage.total_price
  335. # check if tool labels is not empty
  336. labels = agent_thought.tool_labels or {}
  337. tools = agent_thought.tool.split(";") if agent_thought.tool else []
  338. for tool in tools:
  339. if not tool:
  340. continue
  341. if tool not in labels:
  342. tool_label = ToolManager.get_tool_label(tool)
  343. if tool_label:
  344. labels[tool] = tool_label.to_dict()
  345. else:
  346. labels[tool] = {"en_US": tool, "zh_Hans": tool}
  347. agent_thought.tool_labels_str = json.dumps(labels)
  348. if tool_invoke_meta is not None:
  349. if isinstance(tool_invoke_meta, dict):
  350. try:
  351. tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
  352. except Exception:
  353. tool_invoke_meta = json.dumps(tool_invoke_meta)
  354. agent_thought.tool_meta_str = tool_invoke_meta
  355. db.session.commit()
  356. db.session.close()
    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Rebuild the conversation history as prompt messages.

        Carries over any system messages found in prompt_messages, then
        replays every earlier message of this conversation thread as a user
        prompt followed by the assistant's recorded thoughts/tool calls (or
        the plain answer when the message has no agent thoughts).
        """
        result: list[PromptMessage] = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)
        # Load the full conversation, newest first, then narrow to the current
        # thread and restore chronological order.
        messages = (
            (
                db.session.execute(
                    select(Message)
                    .where(Message.conversation_id == self.message.conversation_id)
                    .order_by(Message.created_at.desc())
                )
            )
            .scalars()
            .all()
        )
        messages = list(reversed(extract_thread_messages(messages)))
        for message in messages:
            # Skip the message currently being answered.
            if message.id == self.message.id:
                continue
            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        # tool names are stored as a semicolon-joined string
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            # unparseable input: assume empty args per tool
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            # unparseable observation: reuse it verbatim per tool
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)
                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )
                        # assistant turn with its tool calls, then the tool results
                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))
        db.session.close()
        return result
  434. def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
  435. stmt = select(MessageFile).where(MessageFile.message_id == message.id)
  436. files = db.session.scalars(stmt).all()
  437. if not files:
  438. return UserPromptMessage(content=message.query)
  439. if message.app_model_config:
  440. file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
  441. else:
  442. file_extra_config = None
  443. if not file_extra_config:
  444. return UserPromptMessage(content=message.query)
  445. image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
  446. image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  447. file_objs = file_factory.build_from_message_files(
  448. message_files=files, tenant_id=self.tenant_id, config=file_extra_config
  449. )
  450. if not file_objs:
  451. return UserPromptMessage(content=message.query)
  452. prompt_message_contents: list[PromptMessageContentUnionTypes] = []
  453. for file in file_objs:
  454. prompt_message_contents.append(
  455. file_manager.to_prompt_message_content(
  456. file,
  457. image_detail_config=image_detail_config,
  458. )
  459. )
  460. prompt_message_contents.append(TextPromptMessageContent(data=message.query))
  461. return UserPromptMessage(content=prompt_message_contents)