瀏覽代碼

fix: resolve issue with cot_agent_runner not analyzing user-uploaded images correctly (#5360)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
tags/0.6.12
Xiao Ley 1 年之前
父節點
當前提交
369a395ee9
沒有連結到貢獻者的電子郵件帳戶。
共有 2 個檔案被更改,包括 23 行新增9 行删除
  1. 0
    2
      api/core/agent/cot_agent_runner.py
  2. 23
    7
      api/core/agent/cot_chat_agent_runner.py

+ 0
- 2
api/core/agent/cot_agent_runner.py 查看文件

# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools() tool_instances, self._prompt_messages_tools = self._init_prompt_tools()


prompt_messages = self._organize_prompt_messages()

function_call_state = True function_call_state = True
llm_usage = { llm_usage = {
'usage': None 'usage': None

+ 23
- 7
api/core/agent/cot_chat_agent_runner.py 查看文件

AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder


return SystemPromptMessage(content=system_prompt) return SystemPromptMessage(content=system_prompt)


def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)

prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))

return prompt_messages

def _organize_prompt_messages(self) -> list[PromptMessage]: def _organize_prompt_messages(self) -> list[PromptMessage]:
""" """
Organize Organize
assistant_messages = [assistant_message] assistant_messages = [assistant_message]


# query messages # query messages
query_messages = UserPromptMessage(content=self._query)
query_messages = self._organize_user_query(self._query, [])


if assistant_messages: if assistant_messages:
# organize historic prompt messages # organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([ historic_messages = self._organize_historic_prompt_messages([
system_message, system_message,
query_messages,
*query_messages,
*assistant_messages, *assistant_messages,
UserPromptMessage(content='continue') UserPromptMessage(content='continue')
])
])
messages = [ messages = [
system_message, system_message,
*historic_messages, *historic_messages,
query_messages,
*query_messages,
*assistant_messages, *assistant_messages,
UserPromptMessage(content='continue') UserPromptMessage(content='continue')
] ]
else: else:
# organize historic prompt messages # organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
messages = [system_message, *historic_messages, query_messages]
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
messages = [system_message, *historic_messages, *query_messages]


# join all messages # join all messages
return messages
return messages

Loading…
取消
儲存