Co-authored-by: chenyongzhao <chenyz@mama.cn>
            self.files = application_generate_entity.files
        else:
            self.files = []
        self.query = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
            -> AgentChatAppGenerateEntity:
        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:

            return UserPromptMessage(content=prompt_message_contents)
        else:
            return UserPromptMessage(content=message.query)
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
        return message

    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
        scratchpad: list[AgentScratchpadUnit] = []
        current_scratchpad: AgentScratchpadUnit = None

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=self.history_prompt_messages,
            memory=self.memory
        ).get_prompt()

        for message in self.history_prompt_messages:
            if isinstance(message, AssistantPromptMessage):
                current_scratchpad = AgentScratchpadUnit(
        # organize system prompt
        system_message = self._organize_system_prompt()

        # organize historic prompt messages
        historic_messages = self._historic_prompt_messages

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        if not agent_scratchpad:

        query_messages = UserPromptMessage(content=self._query)

        if assistant_messages:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages([
                system_message,
                query_messages,
                *assistant_messages,
                UserPromptMessage(content='continue')
            ])
            messages = [
                system_message,
                *historic_messages,
                UserPromptMessage(content='continue')
            ]
        else:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
            messages = [system_message, *historic_messages, query_messages]

        # join all messages
        return system_prompt

    def _organize_historic_prompt(self) -> str:
    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
        """
        Organize historic prompt
        """
        historic_prompt_messages = self._historic_prompt_messages
        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
        historic_prompt = ""

        for message in historic_prompt_messages:
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self,
            message: Message, query: str, **kwargs: Any
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config

        prompt_template = app_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
        prompt_messages = self._organize_user_query(query, prompt_messages)
        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
            else:
                assistant_message.content = response

            prompt_messages.append(assistant_message)
            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                }
                tool_responses.append(tool_response)
                prompt_messages = self._organize_assistant_message(
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )
                if tool_response['tool_response'] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=tool_response['tool_response'],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought

            iteration_step += 1
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
        return prompt_messages

    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        Organize assistant message
        """
        prompt_messages = deepcopy(prompt_messages)
        if tool_response is not None:
            prompt_messages.append(
                ToolPromptMessage(
                    content=tool_response,
                    tool_call_id=tool_call_id,
                    name=tool_call_name,
                )
            )

        return prompt_messages
    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As of now, gpt supports both fc and vision at the first iteration.

                for content in prompt_message.content
            ])

        return prompt_messages
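    # Only the tail of _clear_user_prompt_image_messages is visible in this hunk.
    # A minimal sketch of what the elided body likely does, assuming the standard
    # UserPromptMessage / PromptMessageContentType entities: multimodal user
    # messages are flattened so image parts become a plain-text placeholder before
    # subsequent iterations. The '[image]' literal is an illustrative assumption,
    # not a confirmed string from the codebase.
    #
    #     for prompt_message in prompt_messages:
    #         if isinstance(prompt_message, UserPromptMessage) and isinstance(prompt_message.content, list):
    #             prompt_message.content = '\n'.join([
    #                 content.data if content.type == PromptMessageContentType.TEXT
    #                 else '[image]'  # strip image payloads, keep a marker
    #                 for content in prompt_message.content
    #             ])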
    def _organize_prompt_messages(self):
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory
        ).get_prompt()

        prompt_messages = [
            *self.history_prompt_messages,
            *query_prompt_messages,
            *self._current_thoughts
        ]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages
from typing import Optional, cast

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


class AgentHistoryPromptTransform(PromptTransform):
    """
    History Prompt Transform for Agent App
    """
    def __init__(self,
                 model_config: ModelConfigWithCredentialsEntity,
                 prompt_messages: list[PromptMessage],
                 history_messages: list[PromptMessage],
                 memory: Optional[TokenBufferMemory] = None,
                 ):
        self.model_config = model_config
        self.prompt_messages = prompt_messages
        self.history_messages = history_messages
        self.memory = memory
    def get_prompt(self) -> list[PromptMessage]:
        prompt_messages = []
        num_system = 0
        # always keep system prompts, in their original order
        for prompt_message in self.history_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                prompt_messages.append(prompt_message)
                num_system += 1

        if not self.memory:
            return prompt_messages

        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)

        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        curr_message_tokens = model_type_instance.get_num_tokens(
            self.memory.model_instance.model,
            self.memory.model_instance.credentials,
            self.history_messages
        )
        if curr_message_tokens <= max_token_limit:
            return self.history_messages

        # number of prompts appended for the current message group
        num_prompt = 0
        # append prompt messages in descending (most recent first) order
        for prompt_message in self.history_messages[::-1]:
            if isinstance(prompt_message, SystemPromptMessage):
                continue
            prompt_messages.append(prompt_message)
            num_prompt += 1
            # a message group starts with a UserPromptMessage
            if isinstance(prompt_message, UserPromptMessage):
                curr_message_tokens = model_type_instance.get_num_tokens(
                    self.memory.model_instance.model,
                    self.memory.model_instance.credentials,
                    prompt_messages
                )
                # if the current group overflows the token limit, drop the whole group and stop
                if curr_message_tokens > max_token_limit:
                    prompt_messages = prompt_messages[:-num_prompt]
                    break
                num_prompt = 0
        # return prompt messages in ascending (chronological) order
        message_prompts = prompt_messages[num_system:]
        message_prompts.reverse()
        # merge system and message prompts
        prompt_messages = prompt_messages[:num_system]
        prompt_messages.extend(message_prompts)
        return prompt_messages
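# Illustrative only: the truncation strategy above, restated as a standalone,
# runnable sketch over (role, text) tuples so the grouping behaviour is easier
# to see. The names truncate_history / window / group are hypothetical and not
# part of the codebase, and token counting is replaced by a simple message count.

def truncate_history(history: list[tuple[str, str]], max_messages: int) -> list[tuple[str, str]]:
    # system messages are always kept, in their original order
    kept = [m for m in history if m[0] == 'system']
    if len(history) <= max_messages:
        return history
    window: list[tuple[str, str]] = []
    group: list[tuple[str, str]] = []
    # walk backwards; a group is complete once its leading user message is reached
    for message in reversed(history):
        if message[0] == 'system':
            continue
        group.append(message)
        if message[0] == 'user':
            if len(kept) + len(window) + len(group) > max_messages:
                break  # this group would overflow the budget, so drop it entirely
            window.extend(group)
            group = []
    window.reverse()
    return kept + window


if __name__ == '__main__':
    history = [
        ('system', 'S1'), ('user', 'U1'), ('assistant', 'A1'), ('tool', 'T1'),
        ('user', 'U2'), ('assistant', 'A2'),
    ]
    # with a budget of 4 only the most recent user/assistant group survives
    print(truncate_history(history, 4))  # [('system', 'S1'), ('user', 'U2'), ('assistant', 'A2')]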
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation
def test_get_prompt():
    prompt_messages = [
        SystemPromptMessage(content='System Template'),
        UserPromptMessage(content='User Query'),
    ]
    history_messages = [
        SystemPromptMessage(content='System Prompt 1'),
        UserPromptMessage(content='User Prompt 1'),
        AssistantPromptMessage(content='Assistant Thought 1'),
        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
        SystemPromptMessage(content='System Prompt 2'),
        UserPromptMessage(content='User Prompt 2'),
        AssistantPromptMessage(content='Assistant Thought 2'),
        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
        UserPromptMessage(content='User Prompt 3'),
        AssistantPromptMessage(content='Assistant Thought 3'),
    ]

    # use message count instead of token count for testing
    def side_effect_get_num_tokens(*args):
        return len(args[2])

    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)

    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
    provider_model_bundle_mock.model_type_instance = large_language_model_mock

    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.model = 'openai'
    model_config_mock.credentials = {}
    model_config_mock.provider_model_bundle = provider_model_bundle_mock

    memory = TokenBufferMemory(
        conversation=Conversation(),
        model_instance=model_config_mock
    )

    transform = AgentHistoryPromptTransform(
        model_config=model_config_mock,
        prompt_messages=prompt_messages,
        history_messages=history_messages,
        memory=memory
    )

    max_token_limit = 5
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()
    assert len(result) <= max_token_limit
    assert len(result) == 4

    max_token_limit = 20
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()
    assert len(result) <= max_token_limit
    assert len(result) == 12
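    # Why these counts: with a budget of 5, the reverse walk keeps only the most
    # recent group ('User Prompt 3' + 'Assistant Thought 3'); adding the
    # 'User Prompt 2' group (4 more messages) would exceed the limit, so the
    # result is the 2 system prompts plus that last group = 4 messages.
    # With a budget of 20 the full history (12 messages) already fits, so
    # get_prompt returns history_messages unchanged.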