|
|
|
@@ -1,5 +1,6 @@ |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import uuid |
|
|
|
from datetime import datetime |
|
|
|
from mimetypes import guess_extension |
|
|
|
from typing import Optional, Union, cast |
|
|
|
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod |
|
|
|
from core.memory.token_buffer_memory import TokenBufferMemory |
|
|
|
from core.model_manager import ModelInstance |
|
|
|
from core.model_runtime.entities.llm_entities import LLMUsage |
|
|
|
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool |
|
|
|
from core.model_runtime.entities.message_entities import ( |
|
|
|
AssistantPromptMessage, |
|
|
|
PromptMessage, |
|
|
|
PromptMessageTool, |
|
|
|
SystemPromptMessage, |
|
|
|
ToolPromptMessage, |
|
|
|
UserPromptMessage, |
|
|
|
) |
|
|
|
from core.model_runtime.entities.model_entities import ModelFeature |
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel |
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder |
|
|
|
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner): |
|
|
|
self.message = message |
|
|
|
self.user_id = user_id |
|
|
|
self.memory = memory |
|
|
|
self.history_prompt_messages = prompt_messages |
|
|
|
self.history_prompt_messages = self.organize_agent_history( |
|
|
|
prompt_messages=prompt_messages or [] |
|
|
|
) |
|
|
|
self.variables_pool = variables_pool |
|
|
|
self.db_variables_pool = db_variables |
|
|
|
self.model_instance = model_instance |
|
|
|
@@ -504,17 +514,6 @@ class BaseAssistantApplicationRunner(AppRunner): |
|
|
|
agent_thought.tool_labels_str = json.dumps(labels) |
|
|
|
|
|
|
|
db.session.commit() |
|
|
|
|
|
|
|
def get_history_prompt_messages(self) -> list[PromptMessage]: |
|
|
|
""" |
|
|
|
Get history prompt messages |
|
|
|
""" |
|
|
|
if self.history_prompt_messages is None: |
|
|
|
self.history_prompt_messages = db.session.query(PromptMessage).filter( |
|
|
|
PromptMessage.message_id == self.message.id, |
|
|
|
).order_by(PromptMessage.position.asc()).all() |
|
|
|
|
|
|
|
return self.history_prompt_messages |
|
|
|
|
|
|
|
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: |
|
|
|
""" |
|
|
|
@@ -589,4 +588,54 @@ class BaseAssistantApplicationRunner(AppRunner): |
|
|
|
""" |
|
|
|
db_variables.updated_at = datetime.utcnow() |
|
|
|
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) |
|
|
|
db.session.commit() |
|
|
|
db.session.commit() |
|
|
|
|
|
|
|
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: |
|
|
|
""" |
|
|
|
Organize agent history |
|
|
|
""" |
|
|
|
result = [] |
|
|
|
# check if there is a system message in the beginning of the conversation |
|
|
|
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage): |
|
|
|
result.append(prompt_messages[0]) |
|
|
|
|
|
|
|
messages: list[Message] = db.session.query(Message).filter( |
|
|
|
Message.conversation_id == self.message.conversation_id, |
|
|
|
).order_by(Message.created_at.asc()).all() |
|
|
|
|
|
|
|
for message in messages: |
|
|
|
result.append(UserPromptMessage(content=message.query)) |
|
|
|
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts |
|
|
|
for agent_thought in agent_thoughts: |
|
|
|
tools = agent_thought.tool |
|
|
|
if tools: |
|
|
|
tools = tools.split(';') |
|
|
|
tool_calls: list[AssistantPromptMessage.ToolCall] = [] |
|
|
|
tool_call_response: list[ToolPromptMessage] = [] |
|
|
|
tool_inputs = json.loads(agent_thought.tool_input) |
|
|
|
for tool in tools: |
|
|
|
# generate a uuid for tool call |
|
|
|
tool_call_id = str(uuid.uuid4()) |
|
|
|
tool_calls.append(AssistantPromptMessage.ToolCall( |
|
|
|
id=tool_call_id, |
|
|
|
type='function', |
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
name=tool, |
|
|
|
arguments=json.dumps(tool_inputs.get(tool, {})), |
|
|
|
) |
|
|
|
)) |
|
|
|
tool_call_response.append(ToolPromptMessage( |
|
|
|
content=agent_thought.observation, |
|
|
|
name=tool, |
|
|
|
tool_call_id=tool_call_id, |
|
|
|
)) |
|
|
|
|
|
|
|
result.extend([ |
|
|
|
AssistantPromptMessage( |
|
|
|
content=agent_thought.thought, |
|
|
|
tool_calls=tool_calls, |
|
|
|
), |
|
|
|
*tool_call_response |
|
|
|
]) |
|
|
|
|
|
|
|
return result |