Co-authored-by: Joel <iamjoel007@gmail.com>
@@ -5,6 +5,7 @@ from datetime import datetime
 from typing import Optional, Union, cast
 
 from core.agent.entities import AgentEntity, AgentToolEntity
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.base_app_runner import AppRunner
@@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file.message_file_parser import MessageFileParser
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -22,6 +24,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
@@ -37,7 +40,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
 from extensions.ext_database import db
-from models.model import Message, MessageAgentThought
+from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ logger = logging.getLogger(__name__)
 class BaseAgentRunner(AppRunner):
     def __init__(self, tenant_id: str,
                  application_generate_entity: AgentChatAppGenerateEntity,
+                 conversation: Conversation,
                  app_config: AgentChatAppConfig,
                  model_config: ModelConfigWithCredentialsEntity,
                  config: AgentEntity,
@@ -72,6 +76,7 @@ class BaseAgentRunner(AppRunner):
         """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
         self.app_config = app_config
         self.model_config = model_config
         self.config = config
@@ -118,6 +123,12 @@ class BaseAgentRunner(AppRunner):
         else:
             self.stream_tool_call = False
 
+        # check if model supports vision
+        if model_schema and ModelFeature.VISION in (model_schema.features or []):
+            self.files = application_generate_entity.files
+        else:
+            self.files = []
+
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
         """
@@ -412,15 +423,19 @@ class BaseAgentRunner(AppRunner):
         """
         result = []
         # check if there is a system message in the beginning of the conversation
-        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
-            result.append(prompt_messages[0])
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, SystemPromptMessage):
+                result.append(prompt_message)
 
         messages: list[Message] = db.session.query(Message).filter(
             Message.conversation_id == self.message.conversation_id,
         ).order_by(Message.created_at.asc()).all()
 
         for message in messages:
-            result.append(UserPromptMessage(content=message.query))
+            if message.id == self.message.id:
+                continue
+
+            result.append(self.organize_agent_user_prompt(message))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
             if agent_thoughts:
                 for agent_thought in agent_thoughts:
@@ -471,3 +486,32 @@ class BaseAgentRunner(AppRunner):
         db.session.close()
 
         return result
+
+    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
+        message_file_parser = MessageFileParser(
+            tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
+        )
+
+        files = message.message_files
+        if files:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+
+            if file_extra_config:
+                file_objs = message_file_parser.transform_message_files(
+                    files,
+                    file_extra_config
+                )
+            else:
+                file_objs = []
+
+            if not file_objs:
+                return UserPromptMessage(content=message.query)
+            else:
+                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                for file_obj in file_objs:
+                    prompt_message_contents.append(file_obj.prompt_message_content)
+
+                return UserPromptMessage(content=prompt_message_contents)
+        else:
+            return UserPromptMessage(content=message.query)
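Note: organize_agent_user_prompt rebuilds each historical user turn as either a plain string or a list of typed content parts (a text part followed by file parts). A sketch of that content shape under simplified stand-in classes (the real entities live in core.model_runtime.entities.message_entities; build_user_prompt is an illustrative helper):

from dataclasses import dataclass
from typing import Union

@dataclass
class TextPromptMessageContent:     # simplified stand-in
    data: str
    type: str = 'text'

@dataclass
class ImagePromptMessageContent:    # simplified stand-in
    data: str                       # e.g. a URL or base64 payload
    type: str = 'image'

@dataclass
class UserPromptMessage:            # simplified stand-in
    content: Union[str, list]       # plain text, or typed parts when files exist

def build_user_prompt(query: str, file_parts: list) -> UserPromptMessage:
    # mirror the branch in organize_agent_user_prompt: no files -> plain
    # string content; files -> text part followed by the file parts
    if not file_parts:
        return UserPromptMessage(content=query)
    return UserPromptMessage(content=[TextPromptMessageContent(data=query), *file_parts])

msg = build_user_prompt('describe this image',
                        [ImagePromptMessageContent(data='https://example.com/cat.png')])
print(isinstance(msg.content, list))  # True -> multimodal turn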
@@ -19,15 +19,14 @@ from core.model_runtime.entities.message_entities import (
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message
+from models.model import Message
 
 
 class CotAgentRunner(BaseAgentRunner):
     _is_first_iteration = True
     _ignore_observation_providers = ['wenxin']
 
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
             inputs: dict[str, str],
             ) -> Union[Generator, LLMResult]:
@@ -1,6 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
+from copy import deepcopy
 from typing import Any, Union
 
 from core.agent.base_agent_runner import BaseAgentRunner
@@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
-from models.model import Conversation, Message, MessageAgentThought
+from models.model import Message
 
 logger = logging.getLogger(__name__)
 
 
 class FunctionCallAgentRunner(BaseAgentRunner):
-    def run(self, conversation: Conversation,
-            message: Message,
+    def run(self, message: Message,
             query: str,
     ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         prompt_template = app_config.prompt_template.simple_prompt_template or ''
 
         prompt_messages = self.history_prompt_messages
-        prompt_messages = self.organize_prompt_messages(
-            prompt_template=prompt_template,
-            query=query,
-            prompt_messages=prompt_messages
-        )
+        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
+        prompt_messages = self._organize_user_query(query, prompt_messages)
 
         # convert tools into ModelRuntime Tool format
         prompt_messages_tools: list[PromptMessageTool] = []
@@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         # continue to run until there is not any tool call
         function_call_state = True
-        agent_thoughts: list[MessageAgentThought] = []
         llm_usage = {
             'usage': None
         }
@@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     }
                     tool_responses.append(tool_response)
 
-                    prompt_messages = self.organize_prompt_messages(
-                        prompt_template=prompt_template,
-                        query=None,
+                    prompt_messages = self._organize_assistant_message(
                         tool_call_id=tool_call_id,
                         tool_call_name=tool_call_name,
                         tool_response=tool_response['tool_response'],
@@ -324,6 +320,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             iteration_step += 1
 
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -386,29 +384,68 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         return tool_calls
 
-    def organize_prompt_messages(self, prompt_template: str,
-                                 query: str = None,
-                                 tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                 prompt_messages: list[PromptMessage] = None
-                                 ) -> list[PromptMessage]:
+    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
-        Organize prompt messages
+        Initialize system message
         """
-        if not prompt_messages:
-            prompt_messages = [
+        if not prompt_messages and prompt_template:
+            return [
                 SystemPromptMessage(content=prompt_template),
-                UserPromptMessage(content=query),
             ]
+
+        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
+            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
+
+        return prompt_messages
+
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize user query
+        """
+        if self.files:
+            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            for file_obj in self.files:
+                prompt_message_contents.append(file_obj.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
-            if tool_response:
-                prompt_messages = prompt_messages.copy()
-                prompt_messages.append(
-                    ToolPromptMessage(
-                        content=tool_response,
-                        tool_call_id=tool_call_id,
-                        name=tool_call_name,
-                    )
-                )
+            prompt_messages.append(UserPromptMessage(content=query))
 
         return prompt_messages
+
+    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
+                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize assistant message
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        if tool_response is not None:
+            prompt_messages.append(
+                ToolPromptMessage(
+                    content=tool_response,
+                    tool_call_id=tool_call_id,
+                    name=tool_call_name,
+                )
+            )
+
+        return prompt_messages
+
+    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        For now, GPT supports both function calling and vision at the first iteration,
+        so we need to remove the image messages from the prompt messages after the first iteration.
+        """
+        prompt_messages = deepcopy(prompt_messages)
+
+        for prompt_message in prompt_messages:
+            if isinstance(prompt_message, UserPromptMessage):
+                if isinstance(prompt_message.content, list):
+                    prompt_message.content = '\n'.join([
+                        content.data if content.type == PromptMessageContentType.TEXT else
+                        '[image]' if content.type == PromptMessageContentType.IMAGE else
+                        '[file]'
+                        for content in prompt_message.content
+                    ])
+
+        return prompt_messages
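Note: _clear_user_prompt_image_messages flattens any list-style user content back to plain text, replacing image parts with an '[image]' placeholder, and works on a deep copy so the caller's history is untouched. A runnable sketch of that step with simplified stand-in types:

from copy import deepcopy
from dataclasses import dataclass
from typing import Union

@dataclass
class Part:                 # simplified stand-in for prompt message content
    type: str               # 'text' | 'image' | anything else -> '[file]'
    data: str

@dataclass
class UserMsg:              # simplified stand-in for UserPromptMessage
    content: Union[str, list]

def clear_images(messages: list) -> list:
    messages = deepcopy(messages)   # never mutate the caller's history
    for m in messages:
        if isinstance(m, UserMsg) and isinstance(m.content, list):
            m.content = '\n'.join(
                p.data if p.type == 'text' else
                '[image]' if p.type == 'image' else '[file]'
                for p in m.content
            )
    return messages

history = [UserMsg(content=[Part('text', 'what is in this picture?'),
                            Part('image', 'https://example.com/a.png')])]
print(clear_images(history)[0].content)   # "what is in this picture?\n[image]"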
@@ -210,6 +210,7 @@ class AgentChatAppRunner(AppRunner):
             assistant_cot_runner = CotAgentRunner(
                 tenant_id=app_config.tenant_id,
                 application_generate_entity=application_generate_entity,
+                conversation=conversation,
                 app_config=app_config,
                 model_config=application_generate_entity.model_config,
                 config=agent_entity,
@@ -223,7 +224,6 @@ class AgentChatAppRunner(AppRunner):
                 model_instance=model_instance
             )
             invoke_result = assistant_cot_runner.run(
-                conversation=conversation,
                 message=message,
                 query=query,
                 inputs=inputs,
@@ -232,6 +232,7 @@ class AgentChatAppRunner(AppRunner):
             assistant_fc_runner = FunctionCallAgentRunner(
                 tenant_id=app_config.tenant_id,
                 application_generate_entity=application_generate_entity,
+                conversation=conversation,
                 app_config=app_config,
                 model_config=application_generate_entity.model_config,
                 config=agent_entity,
@@ -245,7 +246,6 @@ class AgentChatAppRunner(AppRunner):
                 model_instance=model_instance
             )
             invoke_result = assistant_fc_runner.run(
-                conversation=conversation,
                 message=message,
                 query=query,
             )
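Note: the four hunks above move conversation from the runners' run() methods into their constructors, so helpers such as history rebuilding can reach it through self. A toy sketch of the shape of that refactor (all names simplified stand-ins):

class Conversation:            # stand-in for models.model.Conversation
    def __init__(self, id: str):
        self.id = id

class AgentRunner:             # stand-in for the Cot/FunctionCall runners
    def __init__(self, tenant_id: str, conversation: Conversation):
        self.tenant_id = tenant_id
        self.conversation = conversation   # previously an argument of run()

    def run(self, query: str) -> str:
        # helpers can now reach the conversation through self instead of
        # having it threaded through every call
        return f'answering {query!r} in conversation {self.conversation.id}'

runner = AgentRunner(tenant_id='t1', conversation=Conversation(id='c1'))
print(runner.run('hello'))     # answering 'hello' in conversation c1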
@@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
         return tool_call
 
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
+
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = '\n'.join([
+                                item.data if item.type == PromptMessageContentType.TEXT else
+                                '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
+                                for item in prompt_message.content
+                            ])
+
+        return prompt_messages
+
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict for OpenAI API
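Note: _clear_illegal_prompt_messages works around gpt-4-turbo rejecting requests that carry more than one multimodal user message: when several user turns are present, list-style content is flattened to plain text before the chat completion call. A standalone check of just the gating condition (needs_flattening is an illustrative name, and the dict-based messages are simplified stand-ins):

def needs_flattening(model: str, messages: list) -> bool:
    # only the listed gpt-4-turbo snapshots are rewritten, and only when
    # the request holds more than one user message
    checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
    user_message_count = sum(1 for m in messages if m.get('role') == 'user')
    return model in checklist and user_message_count > 1

msgs = [{'role': 'user'}, {'role': 'assistant'}, {'role': 'user'}]
assert needs_flattening('gpt-4-turbo', msgs) is True
assert needs_flattening('gpt-4', msgs) is False            # not on the checklist
assert needs_flattening('gpt-4-turbo', msgs[:1]) is False  # single user turn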
@@ -229,7 +229,7 @@ export const useChat = (
     // answer
     const responseItem: ChatItem = {
-      id: `${Date.now()}`,
+      id: placeholderAnswerId,
       content: '',
       agent_thoughts: [],
       message_files: [],