| @@ -44,6 +44,7 @@ class MessageListApi(AppApiResource): | |||
| 'position': fields.Integer, | |||
| 'thought': fields.String, | |||
| 'tool': fields.String, | |||
| 'tool_labels': fields.Raw, | |||
| 'tool_input': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'observation': fields.String, | |||
| @@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from events.message_event import message_was_created | |||
| @@ -281,7 +282,7 @@ class GenerateTaskPipeline: | |||
| self._task_state.llm_result.message.content = annotation.content | |||
| elif isinstance(event, QueueAgentThoughtEvent): | |||
| agent_thought = ( | |||
| agent_thought: MessageAgentThought = ( | |||
| db.session.query(MessageAgentThought) | |||
| .filter(MessageAgentThought.id == event.agent_thought_id) | |||
| .first() | |||
| @@ -298,6 +299,7 @@ class GenerateTaskPipeline: | |||
| 'thought': agent_thought.thought, | |||
| 'observation': agent_thought.observation, | |||
| 'tool': agent_thought.tool, | |||
| 'tool_labels': agent_thought.tool_labels, | |||
| 'tool_input': agent_thought.tool_input, | |||
| 'created_at': int(self._message.created_at.timestamp()), | |||
| 'message_files': agent_thought.files | |||
| @@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| message_chain_id=None, | |||
| thought='', | |||
| tool=tool_name, | |||
| tool_labels_str='{}', | |||
| tool_input=tool_input, | |||
| message=message, | |||
| message_token=0, | |||
| @@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| agent_thought.tokens = llm_usage.total_tokens | |||
| agent_thought.total_price = llm_usage.total_price | |||
| # check if tool labels is not empty | |||
| labels = agent_thought.tool_labels or {} | |||
| tools = agent_thought.tool.split(';') if agent_thought.tool else [] | |||
| for tool in tools: | |||
| if not tool: | |||
| continue | |||
| if tool not in labels: | |||
| tool_label = ToolManager.get_tool_label(tool) | |||
| if tool_label: | |||
| labels[tool] = tool_label.to_dict() | |||
| else: | |||
| labels[tool] = {'en_US': tool, 'zh_Hans': tool} | |||
| agent_thought.tool_labels_str = json.dumps(labels) | |||
| db.session.commit() | |||
| def get_history_prompt_messages(self) -> List[PromptMessage]: | |||
| @@ -31,6 +31,7 @@ import mimetypes | |||
| logger = logging.getLogger(__name__) | |||
| _builtin_providers = {} | |||
| _builtin_tools_labels = {} | |||
| class ToolManager: | |||
| @staticmethod | |||
| @@ -233,7 +234,7 @@ class ToolManager: | |||
| if len(_builtin_providers) > 0: | |||
| return list(_builtin_providers.values()) | |||
| builtin_providers = [] | |||
| builtin_providers: List[BuiltinToolProviderController] = [] | |||
| for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): | |||
| if provider.startswith('__'): | |||
| continue | |||
| @@ -264,8 +265,30 @@ class ToolManager: | |||
| # cache the builtin providers | |||
| for provider in builtin_providers: | |||
| _builtin_providers[provider.identity.name] = provider | |||
| for tool in provider.get_tools(): | |||
| _builtin_tools_labels[tool.identity.name] = tool.identity.label | |||
| return builtin_providers | |||
| @staticmethod | |||
| def get_tool_label(tool_name: str) -> Union[I18nObject, None]: | |||
| """ | |||
| get the tool label | |||
| :param tool_name: the name of the tool | |||
| :return: the label of the tool | |||
| """ | |||
| global _builtin_tools_labels | |||
| if len(_builtin_tools_labels) == 0: | |||
| # init the builtin providers | |||
| ToolManager.list_builtin_providers() | |||
| if tool_name not in _builtin_tools_labels: | |||
| return None | |||
| return _builtin_tools_labels[tool_name] | |||
| @staticmethod | |||
| def user_list_providers( | |||
| user_id: str, | |||
| @@ -49,10 +49,11 @@ agent_thought_fields = { | |||
| 'position': fields.Integer, | |||
| 'thought': fields.String, | |||
| 'tool': fields.String, | |||
| 'tool_labels': fields.Raw, | |||
| 'tool_input': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'observation': fields.String, | |||
| 'files': fields.List(fields.String) | |||
| 'files': fields.List(fields.String), | |||
| } | |||
| message_detail_fields = { | |||
| @@ -36,6 +36,7 @@ agent_thought_fields = { | |||
| 'position': fields.Integer, | |||
| 'thought': fields.String, | |||
| 'tool': fields.String, | |||
| 'tool_labels': fields.Raw, | |||
| 'tool_input': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'observation': fields.String, | |||
| @@ -0,0 +1,32 @@ | |||
| """add tool labels to agent thought | |||
| Revision ID: 380c6aa5a70d | |||
| Revises: dfb3b7f477da | |||
| Create Date: 2024-01-24 10:58:15.644445 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = '380c6aa5a70d' | |||
| down_revision = 'dfb3b7f477da' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: | |||
| batch_op.drop_column('tool_labels_str') | |||
| # ### end Alembic commands ### | |||
| @@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model): | |||
| position = db.Column(db.Integer, nullable=False) | |||
| thought = db.Column(db.Text, nullable=True) | |||
| tool = db.Column(db.Text, nullable=True) | |||
| tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) | |||
| tool_input = db.Column(db.Text, nullable=True) | |||
| observation = db.Column(db.Text, nullable=True) | |||
| # plugin_id = db.Column(UUID, nullable=True) ## for future design | |||
| @@ -1030,6 +1031,16 @@ class MessageAgentThought(db.Model): | |||
| return json.loads(self.message_files) | |||
| else: | |||
| return [] | |||
| @property | |||
| def tool_labels(self) -> dict: | |||
| try: | |||
| if self.tool_labels_str: | |||
| return json.loads(self.tool_labels_str) | |||
| else: | |||
| return {} | |||
| except Exception as e: | |||
| return {} | |||
| class DatasetRetrieverResource(db.Model): | |||
| __tablename__ = 'dataset_retriever_resources' | |||