| 'position': fields.Integer, | 'position': fields.Integer, | ||||
| 'thought': fields.String, | 'thought': fields.String, | ||||
| 'tool': fields.String, | 'tool': fields.String, | ||||
| 'tool_labels': fields.Raw, | |||||
| 'tool_input': fields.String, | 'tool_input': fields.String, | ||||
| 'created_at': TimestampField, | 'created_at': TimestampField, | ||||
| 'observation': fields.String, | 'observation': fields.String, |
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.tools.tool_file_manager import ToolFileManager | from core.tools.tool_file_manager import ToolFileManager | ||||
| from core.tools.tool_manager import ToolManager | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.prompt.prompt_template import PromptTemplateParser | from core.prompt.prompt_template import PromptTemplateParser | ||||
| from events.message_event import message_was_created | from events.message_event import message_was_created | ||||
| self._task_state.llm_result.message.content = annotation.content | self._task_state.llm_result.message.content = annotation.content | ||||
| elif isinstance(event, QueueAgentThoughtEvent): | elif isinstance(event, QueueAgentThoughtEvent): | ||||
| agent_thought = ( | |||||
| agent_thought: MessageAgentThought = ( | |||||
| db.session.query(MessageAgentThought) | db.session.query(MessageAgentThought) | ||||
| .filter(MessageAgentThought.id == event.agent_thought_id) | .filter(MessageAgentThought.id == event.agent_thought_id) | ||||
| .first() | .first() | ||||
| 'thought': agent_thought.thought, | 'thought': agent_thought.thought, | ||||
| 'observation': agent_thought.observation, | 'observation': agent_thought.observation, | ||||
| 'tool': agent_thought.tool, | 'tool': agent_thought.tool, | ||||
| 'tool_labels': agent_thought.tool_labels, | |||||
| 'tool_input': agent_thought.tool_input, | 'tool_input': agent_thought.tool_input, | ||||
| 'created_at': int(self._message.created_at.timestamp()), | 'created_at': int(self._message.created_at.timestamp()), | ||||
| 'message_files': agent_thought.files | 'message_files': agent_thought.files |
| message_chain_id=None, | message_chain_id=None, | ||||
| thought='', | thought='', | ||||
| tool=tool_name, | tool=tool_name, | ||||
| tool_labels_str='{}', | |||||
| tool_input=tool_input, | tool_input=tool_input, | ||||
| message=message, | message=message, | ||||
| message_token=0, | message_token=0, | ||||
| agent_thought.tokens = llm_usage.total_tokens | agent_thought.tokens = llm_usage.total_tokens | ||||
| agent_thought.total_price = llm_usage.total_price | agent_thought.total_price = llm_usage.total_price | ||||
| # check if tool labels is not empty | |||||
| labels = agent_thought.tool_labels or {} | |||||
| tools = agent_thought.tool.split(';') if agent_thought.tool else [] | |||||
| for tool in tools: | |||||
| if not tool: | |||||
| continue | |||||
| if tool not in labels: | |||||
| tool_label = ToolManager.get_tool_label(tool) | |||||
| if tool_label: | |||||
| labels[tool] = tool_label.to_dict() | |||||
| else: | |||||
| labels[tool] = {'en_US': tool, 'zh_Hans': tool} | |||||
| agent_thought.tool_labels_str = json.dumps(labels) | |||||
| db.session.commit() | db.session.commit() | ||||
| def get_history_prompt_messages(self) -> List[PromptMessage]: | def get_history_prompt_messages(self) -> List[PromptMessage]: |
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| _builtin_providers = {} | _builtin_providers = {} | ||||
| _builtin_tools_labels = {} | |||||
| class ToolManager: | class ToolManager: | ||||
| @staticmethod | @staticmethod | ||||
| if len(_builtin_providers) > 0: | if len(_builtin_providers) > 0: | ||||
| return list(_builtin_providers.values()) | return list(_builtin_providers.values()) | ||||
| builtin_providers = [] | |||||
| builtin_providers: List[BuiltinToolProviderController] = [] | |||||
| for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): | for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): | ||||
| if provider.startswith('__'): | if provider.startswith('__'): | ||||
| continue | continue | ||||
| # cache the builtin providers | # cache the builtin providers | ||||
| for provider in builtin_providers: | for provider in builtin_providers: | ||||
| _builtin_providers[provider.identity.name] = provider | _builtin_providers[provider.identity.name] = provider | ||||
| for tool in provider.get_tools(): | |||||
| _builtin_tools_labels[tool.identity.name] = tool.identity.label | |||||
| return builtin_providers | return builtin_providers | ||||
| @staticmethod | |||||
| def get_tool_label(tool_name: str) -> Union[I18nObject, None]: | |||||
| """ | |||||
| get the tool label | |||||
| :param tool_name: the name of the tool | |||||
| :return: the label of the tool | |||||
| """ | |||||
| global _builtin_tools_labels | |||||
| if len(_builtin_tools_labels) == 0: | |||||
| # init the builtin providers | |||||
| ToolManager.list_builtin_providers() | |||||
| if tool_name not in _builtin_tools_labels: | |||||
| return None | |||||
| return _builtin_tools_labels[tool_name] | |||||
| @staticmethod | @staticmethod | ||||
| def user_list_providers( | def user_list_providers( | ||||
| user_id: str, | user_id: str, |
| 'position': fields.Integer, | 'position': fields.Integer, | ||||
| 'thought': fields.String, | 'thought': fields.String, | ||||
| 'tool': fields.String, | 'tool': fields.String, | ||||
| 'tool_labels': fields.Raw, | |||||
| 'tool_input': fields.String, | 'tool_input': fields.String, | ||||
| 'created_at': TimestampField, | 'created_at': TimestampField, | ||||
| 'observation': fields.String, | 'observation': fields.String, | ||||
| 'files': fields.List(fields.String) | |||||
| 'files': fields.List(fields.String), | |||||
| } | } | ||||
| message_detail_fields = { | message_detail_fields = { |
| 'position': fields.Integer, | 'position': fields.Integer, | ||||
| 'thought': fields.String, | 'thought': fields.String, | ||||
| 'tool': fields.String, | 'tool': fields.String, | ||||
| 'tool_labels': fields.Raw, | |||||
| 'tool_input': fields.String, | 'tool_input': fields.String, | ||||
| 'created_at': TimestampField, | 'created_at': TimestampField, | ||||
| 'observation': fields.String, | 'observation': fields.String, |
| """add tool labels to agent thought | |||||
| Revision ID: 380c6aa5a70d | |||||
| Revises: dfb3b7f477da | |||||
| Create Date: 2024-01-24 10:58:15.644445 | |||||
| """ | |||||
| from alembic import op | |||||
| import sqlalchemy as sa | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '380c6aa5a70d' | |||||
| down_revision = 'dfb3b7f477da' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False)) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: | |||||
| batch_op.drop_column('tool_labels_str') | |||||
| # ### end Alembic commands ### |
| position = db.Column(db.Integer, nullable=False) | position = db.Column(db.Integer, nullable=False) | ||||
| thought = db.Column(db.Text, nullable=True) | thought = db.Column(db.Text, nullable=True) | ||||
| tool = db.Column(db.Text, nullable=True) | tool = db.Column(db.Text, nullable=True) | ||||
| tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) | |||||
| tool_input = db.Column(db.Text, nullable=True) | tool_input = db.Column(db.Text, nullable=True) | ||||
| observation = db.Column(db.Text, nullable=True) | observation = db.Column(db.Text, nullable=True) | ||||
| # plugin_id = db.Column(UUID, nullable=True) ## for future design | # plugin_id = db.Column(UUID, nullable=True) ## for future design | ||||
| return json.loads(self.message_files) | return json.loads(self.message_files) | ||||
| else: | else: | ||||
| return [] | return [] | ||||
| @property | |||||
| def tool_labels(self) -> dict: | |||||
| try: | |||||
| if self.tool_labels_str: | |||||
| return json.loads(self.tool_labels_str) | |||||
| else: | |||||
| return {} | |||||
| except Exception as e: | |||||
| return {} | |||||
| class DatasetRetrieverResource(db.Model): | class DatasetRetrieverResource(db.Model): | ||||
| __tablename__ = 'dataset_retriever_resources' | __tablename__ = 'dataset_retriever_resources' |