|
|
|
@@ -1,6 +1,7 @@ |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import uuid |
|
|
|
from collections.abc import Mapping, Sequence |
|
|
|
from datetime import datetime, timezone |
|
|
|
from typing import Optional, Union, cast |
|
|
|
|
|
|
|
@@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class BaseAgentRunner(AppRunner): |
|
|
|
def __init__(self, tenant_id: str, |
|
|
|
application_generate_entity: AgentChatAppGenerateEntity, |
|
|
|
conversation: Conversation, |
|
|
|
app_config: AgentChatAppConfig, |
|
|
|
model_config: ModelConfigWithCredentialsEntity, |
|
|
|
config: AgentEntity, |
|
|
|
queue_manager: AppQueueManager, |
|
|
|
message: Message, |
|
|
|
user_id: str, |
|
|
|
memory: Optional[TokenBufferMemory] = None, |
|
|
|
prompt_messages: Optional[list[PromptMessage]] = None, |
|
|
|
variables_pool: Optional[ToolRuntimeVariablePool] = None, |
|
|
|
db_variables: Optional[ToolConversationVariables] = None, |
|
|
|
model_instance: ModelInstance = None |
|
|
|
) -> None: |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
tenant_id: str, |
|
|
|
application_generate_entity: AgentChatAppGenerateEntity, |
|
|
|
conversation: Conversation, |
|
|
|
app_config: AgentChatAppConfig, |
|
|
|
model_config: ModelConfigWithCredentialsEntity, |
|
|
|
config: AgentEntity, |
|
|
|
queue_manager: AppQueueManager, |
|
|
|
message: Message, |
|
|
|
user_id: str, |
|
|
|
memory: Optional[TokenBufferMemory] = None, |
|
|
|
prompt_messages: Optional[list[PromptMessage]] = None, |
|
|
|
variables_pool: Optional[ToolRuntimeVariablePool] = None, |
|
|
|
db_variables: Optional[ToolConversationVariables] = None, |
|
|
|
model_instance: ModelInstance = None, |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Agent runner |
|
|
|
:param tenant_id: tenant id |
|
|
|
@@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner): |
|
|
|
self.message = message |
|
|
|
self.user_id = user_id |
|
|
|
self.memory = memory |
|
|
|
self.history_prompt_messages = self.organize_agent_history( |
|
|
|
prompt_messages=prompt_messages or [] |
|
|
|
) |
|
|
|
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) |
|
|
|
self.variables_pool = variables_pool |
|
|
|
self.db_variables_pool = db_variables |
|
|
|
self.model_instance = model_instance |
|
|
|
@@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner): |
|
|
|
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, |
|
|
|
return_resource=app_config.additional_features.show_retrieve_source, |
|
|
|
invoke_from=application_generate_entity.invoke_from, |
|
|
|
hit_callback=hit_callback |
|
|
|
hit_callback=hit_callback, |
|
|
|
) |
|
|
|
# get how many agent thoughts have been created |
|
|
|
self.agent_thought_count = db.session.query(MessageAgentThought).filter( |
|
|
|
MessageAgentThought.message_id == self.message.id, |
|
|
|
).count() |
|
|
|
self.agent_thought_count = ( |
|
|
|
db.session.query(MessageAgentThought) |
|
|
|
.filter( |
|
|
|
MessageAgentThought.message_id == self.message.id, |
|
|
|
) |
|
|
|
.count() |
|
|
|
) |
|
|
|
db.session.close() |
|
|
|
|
|
|
|
# check if model supports stream tool call |
|
|
|
@@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner): |
|
|
|
self.query = None |
|
|
|
self._current_thoughts: list[PromptMessage] = [] |
|
|
|
|
|
|
|
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ |
|
|
|
-> AgentChatAppGenerateEntity: |
|
|
|
def _repack_app_generate_entity( |
|
|
|
self, app_generate_entity: AgentChatAppGenerateEntity |
|
|
|
) -> AgentChatAppGenerateEntity: |
|
|
|
""" |
|
|
|
Repack app generate entity |
|
|
|
""" |
|
|
|
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: |
|
|
|
app_generate_entity.app_config.prompt_template.simple_prompt_template = '' |
|
|
|
app_generate_entity.app_config.prompt_template.simple_prompt_template = "" |
|
|
|
|
|
|
|
return app_generate_entity |
|
|
|
|
|
|
|
|
|
|
|
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: |
|
|
|
""" |
|
|
|
convert tool to prompt message tool |
|
|
|
convert tool to prompt message tool |
|
|
|
""" |
|
|
|
tool_entity = ToolManager.get_agent_tool_runtime( |
|
|
|
tenant_id=self.tenant_id, |
|
|
|
app_id=self.app_config.app_id, |
|
|
|
agent_tool=tool, |
|
|
|
invoke_from=self.application_generate_entity.invoke_from |
|
|
|
invoke_from=self.application_generate_entity.invoke_from, |
|
|
|
) |
|
|
|
tool_entity.load_variables(self.variables_pool) |
|
|
|
|
|
|
|
@@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner): |
|
|
|
"type": "object", |
|
|
|
"properties": {}, |
|
|
|
"required": [], |
|
|
|
} |
|
|
|
}, |
|
|
|
) |
|
|
|
|
|
|
|
parameters = tool_entity.get_all_runtime_parameters() |
|
|
|
@@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner): |
|
|
|
if parameter.type == ToolParameter.ToolParameterType.SELECT: |
|
|
|
enum = [option.value for option in parameter.options] |
|
|
|
|
|
|
|
message_tool.parameters['properties'][parameter.name] = { |
|
|
|
message_tool.parameters["properties"][parameter.name] = { |
|
|
|
"type": parameter_type, |
|
|
|
"description": parameter.llm_description or '', |
|
|
|
"description": parameter.llm_description or "", |
|
|
|
} |
|
|
|
|
|
|
|
if len(enum) > 0: |
|
|
|
message_tool.parameters['properties'][parameter.name]['enum'] = enum |
|
|
|
message_tool.parameters["properties"][parameter.name]["enum"] = enum |
|
|
|
|
|
|
|
if parameter.required: |
|
|
|
message_tool.parameters['required'].append(parameter.name) |
|
|
|
message_tool.parameters["required"].append(parameter.name) |
|
|
|
|
|
|
|
return message_tool, tool_entity |
|
|
|
|
|
|
|
|
|
|
|
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: |
|
|
|
""" |
|
|
|
convert dataset retriever tool to prompt message tool |
|
|
|
@@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner): |
|
|
|
"type": "object", |
|
|
|
"properties": {}, |
|
|
|
"required": [], |
|
|
|
} |
|
|
|
}, |
|
|
|
) |
|
|
|
|
|
|
|
for parameter in tool.get_runtime_parameters(): |
|
|
|
parameter_type = 'string' |
|
|
|
|
|
|
|
prompt_tool.parameters['properties'][parameter.name] = { |
|
|
|
parameter_type = "string" |
|
|
|
|
|
|
|
prompt_tool.parameters["properties"][parameter.name] = { |
|
|
|
"type": parameter_type, |
|
|
|
"description": parameter.llm_description or '', |
|
|
|
"description": parameter.llm_description or "", |
|
|
|
} |
|
|
|
|
|
|
|
if parameter.required: |
|
|
|
if parameter.name not in prompt_tool.parameters['required']: |
|
|
|
prompt_tool.parameters['required'].append(parameter.name) |
|
|
|
if parameter.name not in prompt_tool.parameters["required"]: |
|
|
|
prompt_tool.parameters["required"].append(parameter.name) |
|
|
|
|
|
|
|
return prompt_tool |
|
|
|
|
|
|
|
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: |
|
|
|
|
|
|
|
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: |
|
|
|
""" |
|
|
|
Init tools |
|
|
|
""" |
|
|
|
@@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner): |
|
|
|
enum = [] |
|
|
|
if parameter.type == ToolParameter.ToolParameterType.SELECT: |
|
|
|
enum = [option.value for option in parameter.options] |
|
|
|
|
|
|
|
prompt_tool.parameters['properties'][parameter.name] = { |
|
|
|
|
|
|
|
prompt_tool.parameters["properties"][parameter.name] = { |
|
|
|
"type": parameter_type, |
|
|
|
"description": parameter.llm_description or '', |
|
|
|
"description": parameter.llm_description or "", |
|
|
|
} |
|
|
|
|
|
|
|
if len(enum) > 0: |
|
|
|
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum |
|
|
|
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum |
|
|
|
|
|
|
|
if parameter.required: |
|
|
|
if parameter.name not in prompt_tool.parameters['required']: |
|
|
|
prompt_tool.parameters['required'].append(parameter.name) |
|
|
|
if parameter.name not in prompt_tool.parameters["required"]: |
|
|
|
prompt_tool.parameters["required"].append(parameter.name) |
|
|
|
|
|
|
|
return prompt_tool |
|
|
|
|
|
|
|
def create_agent_thought(self, message_id: str, message: str, |
|
|
|
tool_name: str, tool_input: str, messages_ids: list[str] |
|
|
|
) -> MessageAgentThought: |
|
|
|
|
|
|
|
def create_agent_thought( |
|
|
|
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] |
|
|
|
) -> MessageAgentThought: |
|
|
|
""" |
|
|
|
Create agent thought |
|
|
|
""" |
|
|
|
thought = MessageAgentThought( |
|
|
|
message_id=message_id, |
|
|
|
message_chain_id=None, |
|
|
|
thought='', |
|
|
|
thought="", |
|
|
|
tool=tool_name, |
|
|
|
tool_labels_str='{}', |
|
|
|
tool_meta_str='{}', |
|
|
|
tool_labels_str="{}", |
|
|
|
tool_meta_str="{}", |
|
|
|
tool_input=tool_input, |
|
|
|
message=message, |
|
|
|
message_token=0, |
|
|
|
message_unit_price=0, |
|
|
|
message_price_unit=0, |
|
|
|
message_files=json.dumps(messages_ids) if messages_ids else '', |
|
|
|
answer='', |
|
|
|
observation='', |
|
|
|
message_files=json.dumps(messages_ids) if messages_ids else "", |
|
|
|
answer="", |
|
|
|
observation="", |
|
|
|
answer_token=0, |
|
|
|
answer_unit_price=0, |
|
|
|
answer_price_unit=0, |
|
|
|
tokens=0, |
|
|
|
total_price=0, |
|
|
|
position=self.agent_thought_count + 1, |
|
|
|
currency='USD', |
|
|
|
currency="USD", |
|
|
|
latency=0, |
|
|
|
created_by_role='account', |
|
|
|
created_by_role="account", |
|
|
|
created_by=self.user_id, |
|
|
|
) |
|
|
|
|
|
|
|
@@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner): |
|
|
|
|
|
|
|
return thought |
|
|
|
|
|
|
|
def save_agent_thought(self, |
|
|
|
agent_thought: MessageAgentThought, |
|
|
|
tool_name: str, |
|
|
|
tool_input: Union[str, dict], |
|
|
|
thought: str, |
|
|
|
observation: Union[str, dict], |
|
|
|
tool_invoke_meta: Union[str, dict], |
|
|
|
answer: str, |
|
|
|
messages_ids: list[str], |
|
|
|
llm_usage: LLMUsage = None) -> MessageAgentThought: |
|
|
|
def save_agent_thought( |
|
|
|
self, |
|
|
|
agent_thought: MessageAgentThought, |
|
|
|
tool_name: str, |
|
|
|
tool_input: Union[str, dict], |
|
|
|
thought: str, |
|
|
|
observation: Union[str, dict], |
|
|
|
tool_invoke_meta: Union[str, dict], |
|
|
|
answer: str, |
|
|
|
messages_ids: list[str], |
|
|
|
llm_usage: LLMUsage = None, |
|
|
|
) -> MessageAgentThought: |
|
|
|
""" |
|
|
|
Save agent thought |
|
|
|
""" |
|
|
|
agent_thought = db.session.query(MessageAgentThought).filter( |
|
|
|
MessageAgentThought.id == agent_thought.id |
|
|
|
).first() |
|
|
|
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() |
|
|
|
|
|
|
|
if thought is not None: |
|
|
|
agent_thought.thought = thought |
|
|
|
@@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner): |
|
|
|
observation = json.dumps(observation, ensure_ascii=False) |
|
|
|
except Exception as e: |
|
|
|
observation = json.dumps(observation) |
|
|
|
|
|
|
|
|
|
|
|
agent_thought.observation = observation |
|
|
|
|
|
|
|
if answer is not None: |
|
|
|
@@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner): |
|
|
|
|
|
|
|
if messages_ids is not None and len(messages_ids) > 0: |
|
|
|
agent_thought.message_files = json.dumps(messages_ids) |
|
|
|
|
|
|
|
|
|
|
|
if llm_usage: |
|
|
|
agent_thought.message_token = llm_usage.prompt_tokens |
|
|
|
agent_thought.message_price_unit = llm_usage.prompt_price_unit |
|
|
|
@@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner): |
|
|
|
|
|
|
|
# check if tool labels is not empty |
|
|
|
labels = agent_thought.tool_labels or {} |
|
|
|
tools = agent_thought.tool.split(';') if agent_thought.tool else [] |
|
|
|
tools = agent_thought.tool.split(";") if agent_thought.tool else [] |
|
|
|
for tool in tools: |
|
|
|
if not tool: |
|
|
|
continue |
|
|
|
@@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner): |
|
|
|
if tool_label: |
|
|
|
labels[tool] = tool_label.to_dict() |
|
|
|
else: |
|
|
|
labels[tool] = {'en_US': tool, 'zh_Hans': tool} |
|
|
|
labels[tool] = {"en_US": tool, "zh_Hans": tool} |
|
|
|
|
|
|
|
agent_thought.tool_labels_str = json.dumps(labels) |
|
|
|
|
|
|
|
@@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner): |
|
|
|
|
|
|
|
db.session.commit() |
|
|
|
db.session.close() |
|
|
|
|
|
|
|
|
|
|
|
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): |
|
|
|
""" |
|
|
|
convert tool variables to db variables |
|
|
|
""" |
|
|
|
db_variables = db.session.query(ToolConversationVariables).filter( |
|
|
|
ToolConversationVariables.conversation_id == self.message.conversation_id, |
|
|
|
).first() |
|
|
|
db_variables = ( |
|
|
|
db.session.query(ToolConversationVariables) |
|
|
|
.filter( |
|
|
|
ToolConversationVariables.conversation_id == self.message.conversation_id, |
|
|
|
) |
|
|
|
.first() |
|
|
|
) |
|
|
|
|
|
|
|
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) |
|
|
|
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) |
|
|
|
@@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner): |
|
|
|
if isinstance(prompt_message, SystemPromptMessage): |
|
|
|
result.append(prompt_message) |
|
|
|
|
|
|
|
messages: list[Message] = db.session.query(Message).filter( |
|
|
|
Message.conversation_id == self.message.conversation_id, |
|
|
|
).order_by(Message.created_at.asc()).all() |
|
|
|
messages: list[Message] = ( |
|
|
|
db.session.query(Message) |
|
|
|
.filter( |
|
|
|
Message.conversation_id == self.message.conversation_id, |
|
|
|
) |
|
|
|
.order_by(Message.created_at.asc()) |
|
|
|
.all() |
|
|
|
) |
|
|
|
|
|
|
|
for message in messages: |
|
|
|
if message.id == self.message.id: |
|
|
|
@@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner): |
|
|
|
for agent_thought in agent_thoughts: |
|
|
|
tools = agent_thought.tool |
|
|
|
if tools: |
|
|
|
tools = tools.split(';') |
|
|
|
tools = tools.split(";") |
|
|
|
tool_calls: list[AssistantPromptMessage.ToolCall] = [] |
|
|
|
tool_call_response: list[ToolPromptMessage] = [] |
|
|
|
try: |
|
|
|
tool_inputs = json.loads(agent_thought.tool_input) |
|
|
|
except Exception as e: |
|
|
|
tool_inputs = { tool: {} for tool in tools } |
|
|
|
tool_inputs = {tool: {} for tool in tools} |
|
|
|
try: |
|
|
|
tool_responses = json.loads(agent_thought.observation) |
|
|
|
except Exception as e: |
|
|
|
@@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner): |
|
|
|
for tool in tools: |
|
|
|
# generate a uuid for tool call |
|
|
|
tool_call_id = str(uuid.uuid4()) |
|
|
|
tool_calls.append(AssistantPromptMessage.ToolCall( |
|
|
|
id=tool_call_id, |
|
|
|
type='function', |
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
tool_calls.append( |
|
|
|
AssistantPromptMessage.ToolCall( |
|
|
|
id=tool_call_id, |
|
|
|
type="function", |
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
name=tool, |
|
|
|
arguments=json.dumps(tool_inputs.get(tool, {})), |
|
|
|
), |
|
|
|
) |
|
|
|
) |
|
|
|
tool_call_response.append( |
|
|
|
ToolPromptMessage( |
|
|
|
content=tool_responses.get(tool, agent_thought.observation), |
|
|
|
name=tool, |
|
|
|
arguments=json.dumps(tool_inputs.get(tool, {})), |
|
|
|
tool_call_id=tool_call_id, |
|
|
|
) |
|
|
|
)) |
|
|
|
tool_call_response.append(ToolPromptMessage( |
|
|
|
content=tool_responses.get(tool, agent_thought.observation), |
|
|
|
name=tool, |
|
|
|
tool_call_id=tool_call_id, |
|
|
|
)) |
|
|
|
|
|
|
|
result.extend([ |
|
|
|
AssistantPromptMessage( |
|
|
|
content=agent_thought.thought, |
|
|
|
tool_calls=tool_calls, |
|
|
|
), |
|
|
|
*tool_call_response |
|
|
|
]) |
|
|
|
) |
|
|
|
|
|
|
|
result.extend( |
|
|
|
[ |
|
|
|
AssistantPromptMessage( |
|
|
|
content=agent_thought.thought, |
|
|
|
tool_calls=tool_calls, |
|
|
|
), |
|
|
|
*tool_call_response, |
|
|
|
] |
|
|
|
) |
|
|
|
if not tools: |
|
|
|
result.append(AssistantPromptMessage(content=agent_thought.thought)) |
|
|
|
else: |
|
|
|
@@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner): |
|
|
|
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) |
|
|
|
|
|
|
|
if file_extra_config: |
|
|
|
file_objs = message_file_parser.transform_message_files( |
|
|
|
files, |
|
|
|
file_extra_config |
|
|
|
) |
|
|
|
file_objs = message_file_parser.transform_message_files(files, file_extra_config) |
|
|
|
else: |
|
|
|
file_objs = [] |
|
|
|
|