|
|
|
@@ -5,6 +5,7 @@ from core.model_runtime.entities.message_entities import ( |
|
|
|
AssistantPromptMessage, |
|
|
|
PromptMessage, |
|
|
|
SystemPromptMessage, |
|
|
|
TextPromptMessageContent, |
|
|
|
UserPromptMessage, |
|
|
|
) |
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder |
|
|
|
@@ -25,6 +26,21 @@ class CotChatAgentRunner(CotAgentRunner): |
|
|
|
|
|
|
|
return SystemPromptMessage(content=system_prompt) |
|
|
|
|
|
|
|
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: |
|
|
|
""" |
|
|
|
Organize user query |
|
|
|
""" |
|
|
|
if self.files: |
|
|
|
prompt_message_contents = [TextPromptMessageContent(data=query)] |
|
|
|
for file_obj in self.files: |
|
|
|
prompt_message_contents.append(file_obj.prompt_message_content) |
|
|
|
|
|
|
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) |
|
|
|
else: |
|
|
|
prompt_messages.append(UserPromptMessage(content=query)) |
|
|
|
|
|
|
|
return prompt_messages |
|
|
|
|
|
|
|
def _organize_prompt_messages(self) -> list[PromptMessage]: |
|
|
|
""" |
|
|
|
Organize |
|
|
|
@@ -51,27 +67,27 @@ class CotChatAgentRunner(CotAgentRunner): |
|
|
|
assistant_messages = [assistant_message] |
|
|
|
|
|
|
|
# query messages |
|
|
|
query_messages = UserPromptMessage(content=self._query) |
|
|
|
query_messages = self._organize_user_query(self._query, []) |
|
|
|
|
|
|
|
if assistant_messages: |
|
|
|
# organize historic prompt messages |
|
|
|
historic_messages = self._organize_historic_prompt_messages([ |
|
|
|
system_message, |
|
|
|
query_messages, |
|
|
|
*query_messages, |
|
|
|
*assistant_messages, |
|
|
|
UserPromptMessage(content='continue') |
|
|
|
]) |
|
|
|
]) |
|
|
|
messages = [ |
|
|
|
system_message, |
|
|
|
*historic_messages, |
|
|
|
query_messages, |
|
|
|
*query_messages, |
|
|
|
*assistant_messages, |
|
|
|
UserPromptMessage(content='continue') |
|
|
|
] |
|
|
|
else: |
|
|
|
# organize historic prompt messages |
|
|
|
historic_messages = self._organize_historic_prompt_messages([system_message, query_messages]) |
|
|
|
messages = [system_message, *historic_messages, query_messages] |
|
|
|
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages]) |
|
|
|
messages = [system_message, *historic_messages, *query_messages] |
|
|
|
|
|
|
|
# join all messages |
|
|
|
return messages |
|
|
|
return messages |