|
|
|
@@ -19,8 +19,6 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner |
|
|
|
|
|
|
|
from models.model import Conversation, Message |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): |
|
|
|
def run(self, model_instance: ModelInstance, |
|
|
|
conversation: Conversation, |
|
|
|
@@ -93,6 +91,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): |
|
|
|
prompt_messages_tools = [] |
|
|
|
|
|
|
|
message_file_ids = [] |
|
|
|
|
|
|
|
agent_thought = self.create_agent_thought( |
|
|
|
message_id=message.id, |
|
|
|
message='', |
|
|
|
@@ -100,7 +99,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): |
|
|
|
tool_input='', |
|
|
|
messages_ids=message_file_ids |
|
|
|
) |
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) |
|
|
|
|
|
|
|
if iteration_step > 1: |
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) |
|
|
|
|
|
|
|
# update prompt messages |
|
|
|
prompt_messages = self._originze_cot_prompt_messages( |
|
|
|
@@ -137,7 +138,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): |
|
|
|
# get llm usage |
|
|
|
if llm_result.usage: |
|
|
|
increse_usage(llm_usage, llm_result.usage) |
|
|
|
|
|
|
|
|
|
|
|
# publish agent thought if it's first iteration |
|
|
|
if iteration_step == 1: |
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) |
|
|
|
|
|
|
|
self.save_agent_thought(agent_thought=agent_thought, |
|
|
|
tool_name=scratchpad.action.action_name if scratchpad.action else '', |
|
|
|
tool_input=scratchpad.action.action_input if scratchpad.action else '', |
|
|
|
@@ -187,7 +192,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): |
|
|
|
tool_call_args = scratchpad.action.action_input |
|
|
|
tool_instance = tool_instances.get(tool_call_name) |
|
|
|
if not tool_instance: |
|
|
|
logger.error(f"failed to find tool instance: {tool_call_name}") |
|
|
|
answer = f"there is not a tool named {tool_call_name}" |
|
|
|
self.save_agent_thought(agent_thought=agent_thought, |
|
|
|
tool_name='', |
|
|
|
@@ -237,7 +241,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): |
|
|
|
|
|
|
|
if error_response: |
|
|
|
observation = error_response |
|
|
|
logger.error(error_response) |
|
|
|
else: |
|
|
|
observation = self._convert_tool_response_to_str(tool_response) |
|
|
|
|
|
|
|
@@ -543,13 +546,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): |
|
|
|
# add assistant message |
|
|
|
if len(agent_scratchpad) > 0: |
|
|
|
prompt_messages.append(AssistantPromptMessage( |
|
|
|
content=(agent_scratchpad[-1].thought or '') + "\n" + (agent_scratchpad[-1].observation or '') |
|
|
|
content=(agent_scratchpad[-1].thought or '') |
|
|
|
)) |
|
|
|
|
|
|
|
|
|
|
|
# add user message |
|
|
|
if len(agent_scratchpad) > 0: |
|
|
|
prompt_messages.append(UserPromptMessage( |
|
|
|
content=input, |
|
|
|
content=(agent_scratchpad[-1].observation or ''), |
|
|
|
)) |
|
|
|
|
|
|
|
return prompt_messages |