Co-authored-by: Joel <iamjoel007@gmail.com>
 from typing import Optional, Union, cast
 from core.agent.entities import AgentEntity, AgentToolEntity
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.base_app_runner import AppRunner
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file.message_file_parser import MessageFileParser
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
 from extensions.ext_database import db
-from models.model import Message, MessageAgentThought
+from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
 logger = logging.getLogger(__name__)
 class BaseAgentRunner(AppRunner):
     def __init__(self, tenant_id: str,
                  application_generate_entity: AgentChatAppGenerateEntity,
+                 conversation: Conversation,
                  app_config: AgentChatAppConfig,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: AgentEntity,
         """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
         self.app_config = app_config
         self.model_config = model_config
         self.config = config
         else:
             self.stream_tool_call = False
+        # check if model supports vision
+        if model_schema and ModelFeature.VISION in (model_schema.features or []):
+            self.files = application_generate_entity.files
+        else:
+            self.files = []
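
Review note: the new vision check gates file inputs on the model's declared feature set, so non-vision models never receive attachments. A minimal self-contained sketch of that gating pattern; the ModelFeature enum and ModelSchema dataclass below are simplified stand-ins for Dify's core.model_runtime entities, not the real classes:

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class ModelFeature(Enum):
    # simplified stand-in for Dify's ModelFeature
    TOOL_CALL = 'tool-call'
    VISION = 'vision'


@dataclass
class ModelSchema:
    # simplified stand-in for the schema a provider declares for a model
    features: Optional[list] = field(default_factory=list)


def gated_files(model_schema: Optional[ModelSchema], files: list) -> list:
    # pass files through only when the model declares vision support;
    # otherwise drop them so the prompt stays text-only
    if model_schema and ModelFeature.VISION in (model_schema.features or []):
        return files
    return []


print(gated_files(ModelSchema(features=[ModelFeature.VISION]), ['photo.png']))  # ['photo.png']
print(gated_files(ModelSchema(), ['photo.png']))                                # []
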
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
         """
         """
         result = []
         # check if there is a system message in the beginning of the conversation
-        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
-            result.append(prompt_messages[0])
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                result.append(prompt_message)
         messages: list[Message] = db.session.query(Message).filter(
             Message.conversation_id == self.message.conversation_id,
         ).order_by(Message.created_at.asc()).all()
         for message in messages:
-            result.append(UserPromptMessage(content=message.query))
+            if message.id == self.message.id:
+                continue
+            result.append(self.organize_agent_user_prompt(message))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
             if agent_thoughts:
                 for agent_thought in agent_thoughts:
         db.session.close()
         return result
+    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+        message_file_parser = MessageFileParser(
+            tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
+        )
+        files = message.message_files
+        if files:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+            if file_extra_config:
+                file_objs = message_file_parser.transform_message_files(
+                    files,
+                    file_extra_config
+                )
+            else:
+                file_objs = []
+            if not file_objs:
+                return UserPromptMessage(content=message.query)
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+                return UserPromptMessage(content=prompt_message_contents)
+        else:
+            return UserPromptMessage(content=message.query)
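
Review note: organize_agent_user_prompt rebuilds a stored history turn as either a plain string or a multimodal content list, depending on whether files were attached. A rough sketch of the resulting shape, using hypothetical stand-in types in place of UserPromptMessage and the prompt-content entities:

from dataclasses import dataclass
from typing import Union


@dataclass
class TextContent:       # stand-in for TextPromptMessageContent
    data: str


@dataclass
class ImageContent:      # stand-in for an image prompt content entity
    url: str


@dataclass
class UserPrompt:        # stand-in for UserPromptMessage
    content: Union[str, list]


def organize_user_prompt(query: str, file_contents: list) -> UserPrompt:
    # no attachments: keep the plain-string form; with attachments the
    # turn becomes a content list, query text first, one entry per file
    if not file_contents:
        return UserPrompt(content=query)
    return UserPrompt(content=[TextContent(data=query), *file_contents])


print(organize_user_prompt('describe this image', [ImageContent(url='https://example.com/x.png')]))
print(organize_user_prompt('hello', []))
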
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message
+from models.model import Message
 class CotAgentRunner(BaseAgentRunner):
     _is_first_iteration = True
     _ignore_observation_providers = ['wenxin']
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
             inputs: dict[str, str],
             ) -> Union[Generator, LLMResult]:
 import json
 import logging
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Any, Union
 from core.agent.base_agent_runner import BaseAgentRunner
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message, MessageAgentThought
+from models.model import Message
 logger = logging.getLogger(__name__)
 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
             ) -> Generator[LLMResultChunk, None, None]:
         """
         prompt_template = app_config.prompt_template.simple_prompt_template or ''
         prompt_messages = self.history_prompt_messages
-        prompt_messages = self.organize_prompt_messages(
-            prompt_template=prompt_template,
-            query=query,
-            prompt_messages=prompt_messages
-        )
+        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
+        prompt_messages = self._organize_user_query(query, prompt_messages)
         # convert tools into ModelRuntime Tool format
         prompt_messages_tools: list[PromptMessageTool] = []
         # continue to run until there is not any tool call
         function_call_state = True
-        agent_thoughts: list[MessageAgentThought] = []
         llm_usage = {
             'usage': None
         }
             }
             tool_responses.append(tool_response)
-            prompt_messages = self.organize_prompt_messages(
-                prompt_template=prompt_template,
-                query=None,
+            prompt_messages = self._organize_assistant_message(
                 tool_call_id=tool_call_id,
                 tool_call_name=tool_call_name,
                 tool_response=tool_response['tool_response'],
         iteration_step += 1
+        prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
         return tool_calls
-    def organize_prompt_messages(self, prompt_template: str,
-                                 query: str = None,
-                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                 prompt_messages: list[PromptMessage] = None
-                                 ) -> list[PromptMessage]:
+    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
-        Organize prompt messages
+        Initialize system message
         """
-        if not prompt_messages:
-            prompt_messages = [
+        if not prompt_messages and prompt_template:
+            return [
                 SystemPromptMessage(content=prompt_template),
-                UserPromptMessage(content=query),
             ]
+        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
+            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
+        return prompt_messages
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize user query
+        """
+        if self.files:
+            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            for file_obj in self.files:
+                prompt_message_contents.append(file_obj.prompt_message_content)
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
-            if tool_response:
-                prompt_messages = prompt_messages.copy()
-                prompt_messages.append(
-                    ToolPromptMessage(
-                        content=tool_response,
-                        tool_call_id=tool_call_id,
-                        name=tool_call_name,
-                    )
-                )
+            prompt_messages.append(UserPromptMessage(content=query))
+        return prompt_messages
+    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
+                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize assistant message
+        """
+        prompt_messages = deepcopy(prompt_messages)
+        if tool_response is not None:
+            prompt_messages.append(
+                ToolPromptMessage(
+                    content=tool_response,
+                    tool_call_id=tool_call_id,
+                    name=tool_call_name,
+                )
+            )
+        return prompt_messages
+    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        As for now, gpt supports both fc and vision at the first iteration.
+        We need to remove the image messages from the prompt messages at the first iteration.
+        """
+        prompt_messages = deepcopy(prompt_messages)
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, UserPromptMessage):
+                if isinstance(prompt_message.content, list):
+                    prompt_message.content = '\n'.join([
+                        content.data if content.type == PromptMessageContentType.TEXT else
+                        '[image]' if content.type == PromptMessageContentType.IMAGE else
+                        '[file]'
+                        for content in prompt_message.content
+                    ])
         return prompt_messages
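
Review note: _clear_user_prompt_image_messages collapses multimodal user content back to plain text once the first iteration is done, keeping text parts and replacing images and files with markers. A self-contained sketch of that flattening, using placeholder types rather than the real PromptMessage classes:

from copy import deepcopy
from dataclasses import dataclass
from typing import Union


@dataclass
class Content:           # stand-in for a prompt message content entity
    type: str            # 'text', 'image' or 'file'
    data: str = ''


@dataclass
class UserMsg:           # stand-in for UserPromptMessage
    content: Union[str, list]


def clear_image_messages(messages: list) -> list:
    # deep-copy first so the caller's history is untouched, then collapse
    # any multimodal user content into one text string with placeholders
    messages = deepcopy(messages)
    for msg in messages:
        if isinstance(msg, UserMsg) and isinstance(msg.content, list):
            msg.content = '\n'.join(
                c.data if c.type == 'text' else
                '[image]' if c.type == 'image' else '[file]'
                for c in msg.content
            )
    return messages


history = [UserMsg(content=[Content('text', 'what is in this picture?'), Content('image')])]
print(clear_image_messages(history)[0].content)
# what is in this picture?
# [image]
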
         assistant_cot_runner = CotAgentRunner(
             tenant_id=app_config.tenant_id,
             application_generate_entity=application_generate_entity,
+            conversation=conversation,
             app_config=app_config,
             model_config=application_generate_entity.model_config,
             config=agent_entity,
             model_instance=model_instance
         )
         invoke_result = assistant_cot_runner.run(
-            conversation=conversation,
             message=message,
             query=query,
             inputs=inputs,
         assistant_fc_runner = FunctionCallAgentRunner(
             tenant_id=app_config.tenant_id,
             application_generate_entity=application_generate_entity,
+            conversation=conversation,
             app_config=app_config,
             model_config=application_generate_entity.model_config,
             config=agent_entity,
             model_instance=model_instance
         )
         invoke_result = assistant_fc_runner.run(
-            conversation=conversation,
             message=message,
             query=query,
         )
         if user:
             extra_model_kwargs['user'] = user
+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
         return tool_call
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = '\n'.join([
+                                item.data if item.type == PromptMessageContentType.TEXT else
+                                '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
+                                for item in prompt_message.content
+                            ])
+        return prompt_messages
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API
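
Review note: this rewrite is doubly conditional. It touches only the two gpt-4-turbo snapshots in checklist, and only once the history holds more than one user message, so a single-turn vision request keeps its image parts intact. A small sketch of just that guard; the function name and role-list representation are illustrative, not part of the patch:

def needs_flattening(model: str, roles: list) -> bool:
    # mirrors the guard above: only the two gpt-4-turbo snapshots are
    # checked, and only once the history holds a second user message
    checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
    return model in checklist and roles.count('user') > 1


print(needs_flattening('gpt-4-turbo', ['system', 'user']))                       # False
print(needs_flattening('gpt-4-turbo', ['system', 'user', 'assistant', 'user']))  # True
print(needs_flattening('gpt-4', ['system', 'user', 'assistant', 'user']))        # False
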
     // answer
     const responseItem: ChatItem = {
-      id: `${Date.now()}`,
+      id: placeholderAnswerId,
       content: '',
       agent_thoughts: [],
       message_files: [],