Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

base_agent_runner.py 20KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. import json
  2. import logging
  3. import uuid
  4. from typing import Optional, Union, cast
  5. from sqlalchemy import select
  6. from core.agent.entities import AgentEntity, AgentToolEntity
  7. from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
  8. from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
  9. from core.app.apps.base_app_queue_manager import AppQueueManager
  10. from core.app.apps.base_app_runner import AppRunner
  11. from core.app.entities.app_invoke_entities import (
  12. AgentChatAppGenerateEntity,
  13. ModelConfigWithCredentialsEntity,
  14. )
  15. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  16. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  17. from core.file import file_manager
  18. from core.memory.token_buffer_memory import TokenBufferMemory
  19. from core.model_manager import ModelInstance
  20. from core.model_runtime.entities import (
  21. AssistantPromptMessage,
  22. LLMUsage,
  23. PromptMessage,
  24. PromptMessageTool,
  25. SystemPromptMessage,
  26. TextPromptMessageContent,
  27. ToolPromptMessage,
  28. UserPromptMessage,
  29. )
  30. from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
  31. from core.model_runtime.entities.model_entities import ModelFeature
  32. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  33. from core.prompt.utils.extract_thread_messages import extract_thread_messages
  34. from core.tools.__base.tool import Tool
  35. from core.tools.entities.tool_entities import (
  36. ToolParameter,
  37. )
  38. from core.tools.tool_manager import ToolManager
  39. from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
  40. from extensions.ext_database import db
  41. from factories import file_factory
  42. from models.model import Conversation, Message, MessageAgentThought, MessageFile
  43. logger = logging.getLogger(__name__)
  44. class BaseAgentRunner(AppRunner):
def __init__(
    self,
    *,
    tenant_id: str,
    application_generate_entity: AgentChatAppGenerateEntity,
    conversation: Conversation,
    app_config: AgentChatAppConfig,
    model_config: ModelConfigWithCredentialsEntity,
    config: AgentEntity,
    queue_manager: AppQueueManager,
    message: Message,
    user_id: str,
    model_instance: ModelInstance,
    memory: Optional[TokenBufferMemory] = None,
    prompt_messages: Optional[list[PromptMessage]] = None,
) -> None:
    """
    Initialize shared state for a single agent run.

    Stores the run context, rebuilds the conversation history into prompt
    messages, creates callback handlers and dataset-retriever tools, counts
    previously persisted agent thoughts for this message, and probes the
    model schema for stream-tool-call / vision support.

    :param tenant_id: workspace/tenant id the run belongs to
    :param application_generate_entity: the generate request for this invocation
    :param conversation: conversation the message belongs to
    :param app_config: agent-chat app configuration
    :param model_config: model configuration with resolved credentials
    :param config: agent configuration (tools etc.)
    :param queue_manager: queue used to publish run events
    :param message: the message currently being answered
    :param user_id: id of the user triggering the run
    :param model_instance: model instance used for LLM calls
    :param memory: optional token-buffer memory
    :param prompt_messages: prompt messages built from the app config; only
        any SystemPromptMessage in them is carried into the history
    """
    self.tenant_id = tenant_id
    self.application_generate_entity = application_generate_entity
    self.conversation = conversation
    self.app_config = app_config
    self.model_config = model_config
    self.config = config
    self.queue_manager = queue_manager
    self.message = message
    self.user_id = user_id
    self.memory = memory
    # Rebuild prior conversation turns (including recorded tool calls)
    # as prompt messages; excludes the current message.
    self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
    self.model_instance = model_instance

    # init callback
    self.agent_callback = DifyAgentCallbackHandler()
    # init dataset tools
    hit_callback = DatasetIndexToolCallbackHandler(
        queue_manager=queue_manager,
        app_id=self.app_config.app_id,
        message_id=message.id,
        user_id=user_id,
        invoke_from=self.application_generate_entity.invoke_from,
    )
    self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
        tenant_id=tenant_id,
        dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
        retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
        return_resource=app_config.additional_features.show_retrieve_source,
        invoke_from=application_generate_entity.invoke_from,
        hit_callback=hit_callback,
        user_id=user_id,
        inputs=cast(dict, application_generate_entity.inputs),
    )
    # get how many agent thoughts have been created; used to position
    # new thoughts created during this run (see create_agent_thought)
    self.agent_thought_count = (
        db.session.query(MessageAgentThought)
        .filter(
            MessageAgentThought.message_id == self.message.id,
        )
        .count()
    )
    db.session.close()

    # check if model supports stream tool call
    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    features = model_schema.features if model_schema and model_schema.features else []
    self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
    # Only forward uploaded files to the model when it advertises vision support.
    self.files = application_generate_entity.files if ModelFeature.VISION in features else []
    self.query: Optional[str] = ""
    # Prompt messages (tool calls / observations) accumulated during the
    # current iteration of the agent loop.
    self._current_thoughts: list[PromptMessage] = []
  110. def _repack_app_generate_entity(
  111. self, app_generate_entity: AgentChatAppGenerateEntity
  112. ) -> AgentChatAppGenerateEntity:
  113. """
  114. Repack app generate entity
  115. """
  116. if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
  117. app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
  118. return app_generate_entity
  119. def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
  120. """
  121. convert tool to prompt message tool
  122. """
  123. tool_entity = ToolManager.get_agent_tool_runtime(
  124. tenant_id=self.tenant_id,
  125. app_id=self.app_config.app_id,
  126. agent_tool=tool,
  127. invoke_from=self.application_generate_entity.invoke_from,
  128. )
  129. assert tool_entity.entity.description
  130. message_tool = PromptMessageTool(
  131. name=tool.tool_name,
  132. description=tool_entity.entity.description.llm,
  133. parameters={
  134. "type": "object",
  135. "properties": {},
  136. "required": [],
  137. },
  138. )
  139. parameters = tool_entity.get_merged_runtime_parameters()
  140. for parameter in parameters:
  141. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  142. continue
  143. parameter_type = parameter.type.as_normal_type()
  144. if parameter.type in {
  145. ToolParameter.ToolParameterType.SYSTEM_FILES,
  146. ToolParameter.ToolParameterType.FILE,
  147. ToolParameter.ToolParameterType.FILES,
  148. }:
  149. continue
  150. enum = []
  151. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  152. enum = [option.value for option in parameter.options] if parameter.options else []
  153. message_tool.parameters["properties"][parameter.name] = (
  154. {
  155. "type": parameter_type,
  156. "description": parameter.llm_description or "",
  157. }
  158. if parameter.input_schema is None
  159. else parameter.input_schema
  160. )
  161. if len(enum) > 0:
  162. message_tool.parameters["properties"][parameter.name]["enum"] = enum
  163. if parameter.required:
  164. message_tool.parameters["required"].append(parameter.name)
  165. return message_tool, tool_entity
  166. def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
  167. """
  168. convert dataset retriever tool to prompt message tool
  169. """
  170. assert tool.entity.description
  171. prompt_tool = PromptMessageTool(
  172. name=tool.entity.identity.name,
  173. description=tool.entity.description.llm,
  174. parameters={
  175. "type": "object",
  176. "properties": {},
  177. "required": [],
  178. },
  179. )
  180. for parameter in tool.get_runtime_parameters():
  181. parameter_type = "string"
  182. prompt_tool.parameters["properties"][parameter.name] = {
  183. "type": parameter_type,
  184. "description": parameter.llm_description or "",
  185. }
  186. if parameter.required:
  187. if parameter.name not in prompt_tool.parameters["required"]:
  188. prompt_tool.parameters["required"].append(parameter.name)
  189. return prompt_tool
  190. def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
  191. """
  192. Init tools
  193. """
  194. tool_instances = {}
  195. prompt_messages_tools = []
  196. for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
  197. try:
  198. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  199. except Exception:
  200. # api tool may be deleted
  201. continue
  202. # save tool entity
  203. tool_instances[tool.tool_name] = tool_entity
  204. # save prompt tool
  205. prompt_messages_tools.append(prompt_tool)
  206. # convert dataset tools into ModelRuntime Tool format
  207. for dataset_tool in self.dataset_tools:
  208. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  209. # save prompt tool
  210. prompt_messages_tools.append(prompt_tool)
  211. # save tool entity
  212. tool_instances[dataset_tool.entity.identity.name] = dataset_tool
  213. return tool_instances, prompt_messages_tools
  214. def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
  215. """
  216. update prompt message tool
  217. """
  218. # try to get tool runtime parameters
  219. tool_runtime_parameters = tool.get_runtime_parameters()
  220. for parameter in tool_runtime_parameters:
  221. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  222. continue
  223. parameter_type = parameter.type.as_normal_type()
  224. if parameter.type in {
  225. ToolParameter.ToolParameterType.SYSTEM_FILES,
  226. ToolParameter.ToolParameterType.FILE,
  227. ToolParameter.ToolParameterType.FILES,
  228. }:
  229. continue
  230. enum = []
  231. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  232. enum = [option.value for option in parameter.options] if parameter.options else []
  233. prompt_tool.parameters["properties"][parameter.name] = (
  234. {
  235. "type": parameter_type,
  236. "description": parameter.llm_description or "",
  237. }
  238. if parameter.input_schema is None
  239. else parameter.input_schema
  240. )
  241. if len(enum) > 0:
  242. prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
  243. if parameter.required:
  244. if parameter.name not in prompt_tool.parameters["required"]:
  245. prompt_tool.parameters["required"].append(parameter.name)
  246. return prompt_tool
  247. def create_agent_thought(
  248. self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
  249. ) -> MessageAgentThought:
  250. """
  251. Create agent thought
  252. """
  253. thought = MessageAgentThought(
  254. message_id=message_id,
  255. message_chain_id=None,
  256. thought="",
  257. tool=tool_name,
  258. tool_labels_str="{}",
  259. tool_meta_str="{}",
  260. tool_input=tool_input,
  261. message=message,
  262. message_token=0,
  263. message_unit_price=0,
  264. message_price_unit=0,
  265. message_files=json.dumps(messages_ids) if messages_ids else "",
  266. answer="",
  267. observation="",
  268. answer_token=0,
  269. answer_unit_price=0,
  270. answer_price_unit=0,
  271. tokens=0,
  272. total_price=0,
  273. position=self.agent_thought_count + 1,
  274. currency="USD",
  275. latency=0,
  276. created_by_role="account",
  277. created_by=self.user_id,
  278. )
  279. db.session.add(thought)
  280. db.session.commit()
  281. db.session.refresh(thought)
  282. db.session.close()
  283. self.agent_thought_count += 1
  284. return thought
  285. def save_agent_thought(
  286. self,
  287. agent_thought: MessageAgentThought,
  288. tool_name: str | None,
  289. tool_input: Union[str, dict, None],
  290. thought: str | None,
  291. observation: Union[str, dict, None],
  292. tool_invoke_meta: Union[str, dict, None],
  293. answer: str | None,
  294. messages_ids: list[str],
  295. llm_usage: LLMUsage | None = None,
  296. ):
  297. """
  298. Save agent thought
  299. """
  300. updated_agent_thought = (
  301. db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
  302. )
  303. if not updated_agent_thought:
  304. raise ValueError("agent thought not found")
  305. agent_thought = updated_agent_thought
  306. if thought:
  307. agent_thought.thought += thought
  308. if tool_name:
  309. agent_thought.tool = tool_name
  310. if tool_input:
  311. if isinstance(tool_input, dict):
  312. try:
  313. tool_input = json.dumps(tool_input, ensure_ascii=False)
  314. except Exception:
  315. tool_input = json.dumps(tool_input)
  316. updated_agent_thought.tool_input = tool_input
  317. if observation:
  318. if isinstance(observation, dict):
  319. try:
  320. observation = json.dumps(observation, ensure_ascii=False)
  321. except Exception:
  322. observation = json.dumps(observation)
  323. updated_agent_thought.observation = observation
  324. if answer:
  325. agent_thought.answer = answer
  326. if messages_ids is not None and len(messages_ids) > 0:
  327. updated_agent_thought.message_files = json.dumps(messages_ids)
  328. if llm_usage:
  329. updated_agent_thought.message_token = llm_usage.prompt_tokens
  330. updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
  331. updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
  332. updated_agent_thought.answer_token = llm_usage.completion_tokens
  333. updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
  334. updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
  335. updated_agent_thought.tokens = llm_usage.total_tokens
  336. updated_agent_thought.total_price = llm_usage.total_price
  337. # check if tool labels is not empty
  338. labels = updated_agent_thought.tool_labels or {}
  339. tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
  340. for tool in tools:
  341. if not tool:
  342. continue
  343. if tool not in labels:
  344. tool_label = ToolManager.get_tool_label(tool)
  345. if tool_label:
  346. labels[tool] = tool_label.to_dict()
  347. else:
  348. labels[tool] = {"en_US": tool, "zh_Hans": tool}
  349. updated_agent_thought.tool_labels_str = json.dumps(labels)
  350. if tool_invoke_meta is not None:
  351. if isinstance(tool_invoke_meta, dict):
  352. try:
  353. tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
  354. except Exception:
  355. tool_invoke_meta = json.dumps(tool_invoke_meta)
  356. updated_agent_thought.tool_meta_str = tool_invoke_meta
  357. db.session.commit()
  358. db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
    """
    Organize agent history.

    Carries over any SystemPromptMessage from *prompt_messages*, then
    replays every earlier message of this conversation's thread
    (excluding the current message) as user / assistant / tool prompt
    messages, reconstructing tool calls from MessageAgentThought rows.

    :param prompt_messages: prompt messages built from the app config;
        only system messages are carried over
    :return: the organized history, oldest first
    """
    result: list[PromptMessage] = []
    # check if there is a system message in the beginning of the conversation
    for prompt_message in prompt_messages:
        if isinstance(prompt_message, SystemPromptMessage):
            result.append(prompt_message)

    messages = (
        (
            db.session.execute(
                select(Message)
                .where(Message.conversation_id == self.message.conversation_id)
                .order_by(Message.created_at.desc())
            )
        )
        .scalars()
        .all()
    )
    # Restore chronological order and keep only the current thread.
    messages = list(reversed(extract_thread_messages(messages)))
    for message in messages:
        if message.id == self.message.id:
            continue

        result.append(self.organize_agent_user_prompt(message))
        agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
        if agent_thoughts:
            for agent_thought in agent_thoughts:
                tools = agent_thought.tool
                if tools:
                    tools = tools.split(";")
                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
                    tool_call_response: list[ToolPromptMessage] = []
                    try:
                        tool_inputs = json.loads(agent_thought.tool_input)
                    except Exception:
                        # tool_input is not valid JSON: fall back to empty
                        # arguments for each tool
                        tool_inputs = {tool: {} for tool in tools}
                    try:
                        tool_responses = json.loads(agent_thought.observation)
                    except Exception:
                        # observation is not valid JSON: share the raw
                        # observation across all tools
                        tool_responses = dict.fromkeys(tools, agent_thought.observation)

                    for tool in tools:
                        # generate a uuid for tool call
                        tool_call_id = str(uuid.uuid4())
                        tool_calls.append(
                            AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type="function",
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                ),
                            )
                        )
                        tool_call_response.append(
                            ToolPromptMessage(
                                content=tool_responses.get(tool, agent_thought.observation),
                                name=tool,
                                tool_call_id=tool_call_id,
                            )
                        )

                    # assistant message carrying the tool calls, followed by
                    # one tool message per call
                    result.extend(
                        [
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response,
                        ]
                    )
                if not tools:
                    result.append(AssistantPromptMessage(content=agent_thought.thought))
        else:
            if message.answer:
                result.append(AssistantPromptMessage(content=message.answer))

    db.session.close()

    return result
  436. def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
  437. files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
  438. if not files:
  439. return UserPromptMessage(content=message.query)
  440. if message.app_model_config:
  441. file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
  442. else:
  443. file_extra_config = None
  444. if not file_extra_config:
  445. return UserPromptMessage(content=message.query)
  446. image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
  447. image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  448. file_objs = file_factory.build_from_message_files(
  449. message_files=files, tenant_id=self.tenant_id, config=file_extra_config
  450. )
  451. if not file_objs:
  452. return UserPromptMessage(content=message.query)
  453. prompt_message_contents: list[PromptMessageContentUnionTypes] = []
  454. prompt_message_contents.append(TextPromptMessageContent(data=message.query))
  455. for file in file_objs:
  456. prompt_message_contents.append(
  457. file_manager.to_prompt_message_content(
  458. file,
  459. image_detail_config=image_detail_config,
  460. )
  461. )
  462. return UserPromptMessage(content=prompt_message_contents)