Co-authored-by: chenyongzhao <chenyz@mama.cn>
@@ -128,6 +128,8 @@ class BaseAgentRunner(AppRunner):
            self.files = application_generate_entity.files
        else:
            self.files = []
        self.query = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
            -> AgentChatAppGenerateEntity:
@@ -464,7 +466,7 @@ class BaseAgentRunner(AppRunner):
        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
@@ -545,3 +547,4 @@ class BaseAgentRunner(AppRunner):
            return UserPromptMessage(content=prompt_message_contents)
        else:
            return UserPromptMessage(content=message.query)
@@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
@@ -373,7 +374,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        return message

    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
@@ -381,6 +382,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        scratchpad: list[AgentScratchpadUnit] = []
        current_scratchpad: AgentScratchpadUnit = None

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=self.history_prompt_messages,
            memory=self.memory
        ).get_prompt()

        for message in self.history_prompt_messages:
            if isinstance(message, AssistantPromptMessage):
                current_scratchpad = AgentScratchpadUnit(
@@ -32,9 +32,6 @@ class CotChatAgentRunner(CotAgentRunner):
        # organize system prompt
        system_message = self._organize_system_prompt()

        # organize historic prompt messages
        historic_messages = self._historic_prompt_messages

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        if not agent_scratchpad:
@@ -57,6 +54,13 @@ class CotChatAgentRunner(CotAgentRunner):
        query_messages = UserPromptMessage(content=self._query)

        if assistant_messages:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages([
                system_message,
                query_messages,
                *assistant_messages,
                UserPromptMessage(content='continue')
            ])
            messages = [
                system_message,
                *historic_messages,
@@ -65,6 +69,8 @@ class CotChatAgentRunner(CotAgentRunner):
                UserPromptMessage(content='continue')
            ]
        else:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
            messages = [system_message, *historic_messages, query_messages]

        # join all messages
@@ -19,11 +19,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
        return system_prompt

    def _organize_historic_prompt(self) -> str:
    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
        """
        Organize historic prompt
        """
        historic_prompt_messages = self._historic_prompt_messages
        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
        historic_prompt = ""

        for message in historic_prompt_messages:
@@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
@@ -24,21 +25,18 @@ from models.model import Message
logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self,
            message: Message, query: str, **kwargs: Any
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config

        prompt_template = app_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
        prompt_messages = self._organize_user_query(query, prompt_messages)

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()
@@ -81,6 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +202,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
            else:
                assistant_message.content = response

            prompt_messages.append(assistant_message)
            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
@@ -265,12 +264,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                    }
                    tool_responses.append(tool_response)

                    prompt_messages = self._organize_assistant_message(
                        tool_call_id=tool_call_id,
                        tool_call_name=tool_call_name,
                        tool_response=tool_response['tool_response'],
                        prompt_messages=prompt_messages,
                    )
                    if tool_response['tool_response'] is not None:
                        self._current_thoughts.append(
                            ToolPromptMessage(
                                content=tool_response['tool_response'],
                                tool_call_id=tool_call_id,
                                name=tool_call_name,
                            )
                        )

                if len(tool_responses) > 0:
                    # save agent thought
@@ -300,8 +301,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):

            iteration_step += 1

            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +392,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):

        return prompt_messages

    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        Organize assistant message
        """
        prompt_messages = deepcopy(prompt_messages)

        if tool_response is not None:
            prompt_messages.append(
                ToolPromptMessage(
                    content=tool_response,
                    tool_call_id=tool_call_id,
                    name=tool_call_name,
                )
            )

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As for now, gpt supports both fc and vision at the first iteration.
@@ -428,4 +409,26 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                        for content in prompt_message.content
                    ])

        return prompt_messages
        return prompt_messages

    def _organize_prompt_messages(self):
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory
        ).get_prompt()

        prompt_messages = [
            *self.history_prompt_messages,
            *query_prompt_messages,
            *self._current_thoughts
        ]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages
@@ -0,0 +1,82 @@
from typing import Optional, cast

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


class AgentHistoryPromptTransform(PromptTransform):
    """
    History Prompt Transform for Agent App
    """
    def __init__(self,
                 model_config: ModelConfigWithCredentialsEntity,
                 prompt_messages: list[PromptMessage],
                 history_messages: list[PromptMessage],
                 memory: Optional[TokenBufferMemory] = None,
                 ):
        self.model_config = model_config
        self.prompt_messages = prompt_messages
        self.history_messages = history_messages
        self.memory = memory

    def get_prompt(self) -> list[PromptMessage]:
        prompt_messages = []
        num_system = 0
        for prompt_message in self.history_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                prompt_messages.append(prompt_message)
                num_system += 1

        if not self.memory:
            return prompt_messages

        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)

        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        curr_message_tokens = model_type_instance.get_num_tokens(
            self.memory.model_instance.model,
            self.memory.model_instance.credentials,
            self.history_messages
        )
        if curr_message_tokens <= max_token_limit:
            return self.history_messages

        # number of prompt messages appended for the current turn
        num_prompt = 0
        # append prompt messages in desc order
        for prompt_message in self.history_messages[::-1]:
            if isinstance(prompt_message, SystemPromptMessage):
                continue
            prompt_messages.append(prompt_message)
            num_prompt += 1
            # a conversation turn starts with a UserPromptMessage
            if isinstance(prompt_message, UserPromptMessage):
                curr_message_tokens = model_type_instance.get_num_tokens(
                    self.memory.model_instance.model,
                    self.memory.model_instance.credentials,
                    prompt_messages
                )
                # if this turn overflows the token limit, drop the whole turn and break
                if curr_message_tokens > max_token_limit:
                    prompt_messages = prompt_messages[:-num_prompt]
                    break
                num_prompt = 0
        # return prompt messages in asc order
        message_prompts = prompt_messages[num_system:]
        message_prompts.reverse()

        # merge system and message prompts
        prompt_messages = prompt_messages[:num_system]
        prompt_messages.extend(message_prompts)
        return prompt_messages
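
A minimal sketch of how a runner would call the new transform, assuming the runner attributes used in the hunks above (self.model_config, self.history_prompt_messages, self._current_thoughts, self.memory) and a query_prompt_messages list built for the current turn:

    # trim stored history to the remaining token budget, reserving room for the
    # current turn (query + tool-call thoughts), then prepend it to that turn
    trimmed_history = AgentHistoryPromptTransform(
        model_config=self.model_config,
        prompt_messages=[*query_prompt_messages, *self._current_thoughts],
        history_messages=self.history_prompt_messages,
        memory=self.memory
    ).get_prompt()
    prompt_messages = [*trimmed_history, *query_prompt_messages, *self._current_thoughts]
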
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation


def test_get_prompt():
    prompt_messages = [
        SystemPromptMessage(content='System Template'),
        UserPromptMessage(content='User Query'),
    ]
    history_messages = [
        SystemPromptMessage(content='System Prompt 1'),
        UserPromptMessage(content='User Prompt 1'),
        AssistantPromptMessage(content='Assistant Thought 1'),
        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
        SystemPromptMessage(content='System Prompt 2'),
        UserPromptMessage(content='User Prompt 2'),
        AssistantPromptMessage(content='Assistant Thought 2'),
        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
        UserPromptMessage(content='User Prompt 3'),
        AssistantPromptMessage(content='Assistant Thought 3'),
    ]

    # use message number instead of token for testing
    def side_effect_get_num_tokens(*args):
        return len(args[2])

    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)

    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
    provider_model_bundle_mock.model_type_instance = large_language_model_mock

    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.model = 'openai'
    model_config_mock.credentials = {}
    model_config_mock.provider_model_bundle = provider_model_bundle_mock

    memory = TokenBufferMemory(
        conversation=Conversation(),
        model_instance=model_config_mock
    )

    transform = AgentHistoryPromptTransform(
        model_config=model_config_mock,
        prompt_messages=prompt_messages,
        history_messages=history_messages,
        memory=memory
    )

    max_token_limit = 5
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()

    assert len(result) <= max_token_limit
    assert len(result) == 4

    max_token_limit = 20
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()

    assert len(result) <= max_token_limit
    assert len(result) == 12
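
Why the expected counts hold, assuming get_prompt behaves as in the hunk above: the mock counts one token per message, so with max_token_limit = 5 the full 12-message history overflows; the transform keeps the two SystemPromptMessages, walks the history newest-first, and only the last turn (UserPromptMessage 'User Prompt 3' plus AssistantPromptMessage 'Assistant Thought 3') fits before the preceding turn would push the count past 5, giving 4 messages. With max_token_limit = 20 the 12-message history is within budget and is returned unchanged.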