Signed-off-by: -LAN- <laipz8200@outlook.com>tags/0.15.5
| @@ -77,5 +77,4 @@ | |||
| - onebot | |||
| - regex | |||
| - trello | |||
| - vanna | |||
| - fal | |||
| @@ -1,134 +0,0 @@ | |||
| from typing import Any, Union | |||
| from vanna.remote import VannaDefault # type: ignore | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.errors import ToolProviderCredentialValidationError | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| class VannaTool(BuiltinTool): | |||
| def _invoke( | |||
| self, user_id: str, tool_parameters: dict[str, Any] | |||
| ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | |||
| """ | |||
| invoke tools | |||
| """ | |||
| # Ensure runtime and credentials | |||
| if not self.runtime or not self.runtime.credentials: | |||
| raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") | |||
| api_key = self.runtime.credentials.get("api_key", None) | |||
| if not api_key: | |||
| raise ToolProviderCredentialValidationError("Please input api key") | |||
| model = tool_parameters.get("model", "") | |||
| if not model: | |||
| return self.create_text_message("Please input RAG model") | |||
| prompt = tool_parameters.get("prompt", "") | |||
| if not prompt: | |||
| return self.create_text_message("Please input prompt") | |||
| url = tool_parameters.get("url", "") | |||
| if not url: | |||
| return self.create_text_message("Please input URL/Host/DSN") | |||
| db_name = tool_parameters.get("db_name", "") | |||
| username = tool_parameters.get("username", "") | |||
| password = tool_parameters.get("password", "") | |||
| port = tool_parameters.get("port", 0) | |||
| base_url = self.runtime.credentials.get("base_url", None) | |||
| vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url}) | |||
| db_type = tool_parameters.get("db_type", "") | |||
| if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: | |||
| if not db_name: | |||
| return self.create_text_message("Please input database name") | |||
| if not username: | |||
| return self.create_text_message("Please input username") | |||
| if port < 1: | |||
| return self.create_text_message("Please input port") | |||
| schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS" | |||
| match db_type: | |||
| case "SQLite": | |||
| schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null" | |||
| vn.connect_to_sqlite(url) | |||
| case "Postgres": | |||
| vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port) | |||
| case "DuckDB": | |||
| vn.connect_to_duckdb(url=url) | |||
| case "SQLServer": | |||
| vn.connect_to_mssql(url) | |||
| case "MySQL": | |||
| vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port) | |||
| case "Oracle": | |||
| vn.connect_to_oracle(user=username, password=password, dsn=url) | |||
| case "Hive": | |||
| vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port) | |||
| case "ClickHouse": | |||
| vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port) | |||
| enable_training = tool_parameters.get("enable_training", False) | |||
| reset_training_data = tool_parameters.get("reset_training_data", False) | |||
| if enable_training: | |||
| if reset_training_data: | |||
| existing_training_data = vn.get_training_data() | |||
| if len(existing_training_data) > 0: | |||
| for _, training_data in existing_training_data.iterrows(): | |||
| vn.remove_training_data(training_data["id"]) | |||
| ddl = tool_parameters.get("ddl", "") | |||
| question = tool_parameters.get("question", "") | |||
| sql = tool_parameters.get("sql", "") | |||
| memos = tool_parameters.get("memos", "") | |||
| training_metadata = tool_parameters.get("training_metadata", False) | |||
| if training_metadata: | |||
| if db_type == "SQLite": | |||
| df_ddl = vn.run_sql(schema_sql) | |||
| for ddl in df_ddl["sql"].to_list(): | |||
| vn.train(ddl=ddl) | |||
| else: | |||
| df_information_schema = vn.run_sql(schema_sql) | |||
| plan = vn.get_training_plan_generic(df_information_schema) | |||
| vn.train(plan=plan) | |||
| if ddl: | |||
| vn.train(ddl=ddl) | |||
| if sql: | |||
| if question: | |||
| vn.train(question=question, sql=sql) | |||
| else: | |||
| vn.train(sql=sql) | |||
| if memos: | |||
| vn.train(documentation=memos) | |||
| ######################################################################################### | |||
| # Due to CVE-2024-5565, we have to disable the chart generation feature | |||
| # The Vanna library uses a prompt function to present the user with visualized results, | |||
| # it is possible to alter the prompt using prompt injection and run arbitrary Python code | |||
| # instead of the intended visualization code. | |||
| # Specifically - allowing external input to the library’s “ask” method | |||
| # with "visualize" set to True (default behavior) leads to remote code execution. | |||
| # Affected versions: <= 0.5.5 | |||
| ######################################################################################### | |||
| allow_llm_to_see_data = tool_parameters.get("allow_llm_to_see_data", False) | |||
| res = vn.ask( | |||
| prompt, print_results=False, auto_train=True, visualize=False, allow_llm_to_see_data=allow_llm_to_see_data | |||
| ) | |||
| result = [] | |||
| if res is not None: | |||
| result.append(self.create_text_message(res[0])) | |||
| if len(res) > 1 and res[1] is not None: | |||
| result.append(self.create_text_message(res[1].to_markdown())) | |||
| if len(res) > 2 and res[2] is not None: | |||
| result.append( | |||
| self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"}) | |||
| ) | |||
| return result | |||
| @@ -1,213 +0,0 @@ | |||
| identity: | |||
| name: vanna | |||
| author: QCTC | |||
| label: | |||
| en_US: Vanna.AI | |||
| zh_Hans: Vanna.AI | |||
| description: | |||
| human: | |||
| en_US: The fastest way to get actionable insights from your database just by asking questions. | |||
| zh_Hans: 一个基于大模型和RAG的Text2SQL工具。 | |||
| llm: A tool for converting text to SQL. | |||
| parameters: | |||
| - name: prompt | |||
| type: string | |||
| required: true | |||
| label: | |||
| en_US: Prompt | |||
| zh_Hans: 提示词 | |||
| pt_BR: Prompt | |||
| human_description: | |||
| en_US: used for generating SQL | |||
| zh_Hans: 用于生成SQL | |||
| llm_description: key words for generating SQL | |||
| form: llm | |||
| - name: model | |||
| type: string | |||
| required: true | |||
| label: | |||
| en_US: RAG Model | |||
| zh_Hans: RAG模型 | |||
| human_description: | |||
| en_US: RAG Model for your database DDL | |||
| zh_Hans: 存储数据库训练数据的RAG模型 | |||
| llm_description: RAG Model for generating SQL | |||
| form: llm | |||
| - name: db_type | |||
| type: select | |||
| required: true | |||
| options: | |||
| - value: SQLite | |||
| label: | |||
| en_US: SQLite | |||
| zh_Hans: SQLite | |||
| - value: Postgres | |||
| label: | |||
| en_US: Postgres | |||
| zh_Hans: Postgres | |||
| - value: DuckDB | |||
| label: | |||
| en_US: DuckDB | |||
| zh_Hans: DuckDB | |||
| - value: SQLServer | |||
| label: | |||
| en_US: Microsoft SQL Server | |||
| zh_Hans: 微软 SQL Server | |||
| - value: MySQL | |||
| label: | |||
| en_US: MySQL | |||
| zh_Hans: MySQL | |||
| - value: Oracle | |||
| label: | |||
| en_US: Oracle | |||
| zh_Hans: Oracle | |||
| - value: Hive | |||
| label: | |||
| en_US: Hive | |||
| zh_Hans: Hive | |||
| - value: ClickHouse | |||
| label: | |||
| en_US: ClickHouse | |||
| zh_Hans: ClickHouse | |||
| default: SQLite | |||
| label: | |||
| en_US: DB Type | |||
| zh_Hans: 数据库类型 | |||
| human_description: | |||
| en_US: Database type. | |||
| zh_Hans: 选择要链接的数据库类型。 | |||
| form: form | |||
| - name: url | |||
| type: string | |||
| required: true | |||
| label: | |||
| en_US: URL/Host/DSN | |||
| zh_Hans: URL/Host/DSN | |||
| human_description: | |||
| en_US: Please input depending on DB type, visit https://vanna.ai/docs/ for more specification | |||
| zh_Hans: 请根据数据库类型,填入对应值,详情参考https://vanna.ai/docs/ | |||
| form: form | |||
| - name: db_name | |||
| type: string | |||
| required: false | |||
| label: | |||
| en_US: DB name | |||
| zh_Hans: 数据库名 | |||
| human_description: | |||
| en_US: Database name | |||
| zh_Hans: 数据库名 | |||
| form: form | |||
| - name: username | |||
| type: string | |||
| required: false | |||
| label: | |||
| en_US: Username | |||
| zh_Hans: 用户名 | |||
| human_description: | |||
| en_US: Username | |||
| zh_Hans: 用户名 | |||
| form: form | |||
| - name: password | |||
| type: secret-input | |||
| required: false | |||
| label: | |||
| en_US: Password | |||
| zh_Hans: 密码 | |||
| human_description: | |||
| en_US: Password | |||
| zh_Hans: 密码 | |||
| form: form | |||
| - name: port | |||
| type: number | |||
| required: false | |||
| label: | |||
| en_US: Port | |||
| zh_Hans: 端口 | |||
| human_description: | |||
| en_US: Port | |||
| zh_Hans: 端口 | |||
| form: form | |||
| - name: ddl | |||
| type: string | |||
| required: false | |||
| label: | |||
| en_US: Training DDL | |||
| zh_Hans: 训练DDL | |||
| human_description: | |||
| en_US: DDL statements for training data | |||
| zh_Hans: 用于训练RAG Model的建表语句 | |||
| form: llm | |||
| - name: question | |||
| type: string | |||
| required: false | |||
| label: | |||
| en_US: Training Question | |||
| zh_Hans: 训练问题 | |||
| human_description: | |||
| en_US: Question-SQL Pairs | |||
| zh_Hans: Question-SQL中的问题 | |||
| form: llm | |||
| - name: sql | |||
| type: string | |||
| required: false | |||
| label: | |||
| en_US: Training SQL | |||
| zh_Hans: 训练SQL | |||
| human_description: | |||
| en_US: SQL queries to your training data | |||
| zh_Hans: 用于训练RAG Model的SQL语句 | |||
| form: llm | |||
| - name: memos | |||
| type: string | |||
| required: false | |||
| label: | |||
| en_US: Training Memos | |||
| zh_Hans: 训练说明 | |||
| human_description: | |||
| en_US: Sometimes you may want to add documentation about your business terminology or definitions | |||
| zh_Hans: 添加更多关于数据库的业务说明 | |||
| form: llm | |||
| - name: enable_training | |||
| type: boolean | |||
| required: false | |||
| default: false | |||
| label: | |||
| en_US: Training Data | |||
| zh_Hans: 训练数据 | |||
| human_description: | |||
| en_US: You only need to train once. Do not train again unless you want to add more training data | |||
| zh_Hans: 训练数据无更新时,训练一次即可 | |||
| form: form | |||
| - name: reset_training_data | |||
| type: boolean | |||
| required: false | |||
| default: false | |||
| label: | |||
| en_US: Reset Training Data | |||
| zh_Hans: 重置训练数据 | |||
| human_description: | |||
| en_US: Remove all training data in the current RAG Model | |||
| zh_Hans: 删除当前RAG Model中的所有训练数据 | |||
| form: form | |||
| - name: training_metadata | |||
| type: boolean | |||
| required: false | |||
| default: false | |||
| label: | |||
| en_US: Training Metadata | |||
| zh_Hans: 训练元数据 | |||
| human_description: | |||
| en_US: If enabled, it will attempt to train on the metadata of that database | |||
| zh_Hans: 是否自动从数据库获取元数据来训练 | |||
| form: form | |||
| - name: allow_llm_to_see_data | |||
| type: boolean | |||
| required: false | |||
| default: false | |||
| label: | |||
| en_US: Whether to allow the LLM to see the data | |||
| zh_Hans: 是否允许LLM查看数据 | |||
| human_description: | |||
| en_US: Whether to allow the LLM to see the data | |||
| zh_Hans: 是否允许LLM查看数据 | |||
| form: form | |||
| @@ -1,46 +0,0 @@ | |||
| import re | |||
| from typing import Any | |||
| from urllib.parse import urlparse | |||
| from core.tools.errors import ToolProviderCredentialValidationError | |||
| from core.tools.provider.builtin.vanna.tools.vanna import VannaTool | |||
| from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController | |||
| class VannaProvider(BuiltinToolProviderController): | |||
| def _get_protocol_and_main_domain(self, url): | |||
| parsed_url = urlparse(url) | |||
| protocol = parsed_url.scheme | |||
| hostname = parsed_url.hostname | |||
| port = f":{parsed_url.port}" if parsed_url.port else "" | |||
| # Check if the hostname is an IP address | |||
| is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None | |||
| # Return the full hostname (with port if present) for IP addresses, otherwise return the main domain | |||
| main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port | |||
| return f"{protocol}://{main_domain}" | |||
| def _validate_credentials(self, credentials: dict[str, Any]) -> None: | |||
| base_url = credentials.get("base_url") | |||
| if not base_url: | |||
| base_url = "https://ask.vanna.ai/rpc" | |||
| else: | |||
| base_url = base_url.removesuffix("/") | |||
| credentials["base_url"] = base_url | |||
| try: | |||
| VannaTool().fork_tool_runtime( | |||
| runtime={ | |||
| "credentials": credentials, | |||
| } | |||
| ).invoke( | |||
| user_id="", | |||
| tool_parameters={ | |||
| "model": "chinook", | |||
| "db_type": "SQLite", | |||
| "url": f"{self._get_protocol_and_main_domain(credentials['base_url'])}/Chinook.sqlite", | |||
| "query": "What are the top 10 customers by sales?", | |||
| }, | |||
| ) | |||
| except Exception as e: | |||
| raise ToolProviderCredentialValidationError(str(e)) | |||
| @@ -1,35 +0,0 @@ | |||
| identity: | |||
| author: QCTC | |||
| name: vanna | |||
| label: | |||
| en_US: Vanna.AI | |||
| zh_Hans: Vanna.AI | |||
| description: | |||
| en_US: The fastest way to get actionable insights from your database just by asking questions. | |||
| zh_Hans: 一个基于大模型和RAG的Text2SQL工具。 | |||
| icon: icon.png | |||
| tags: | |||
| - utilities | |||
| - productivity | |||
| credentials_for_provider: | |||
| api_key: | |||
| type: secret-input | |||
| required: true | |||
| label: | |||
| en_US: API key | |||
| zh_Hans: API key | |||
| placeholder: | |||
| en_US: Please input your API key | |||
| zh_Hans: 请输入你的 API key | |||
| pt_BR: Please input your API key | |||
| help: | |||
| en_US: Get your API key from Vanna.AI | |||
| zh_Hans: 从 Vanna.AI 获取你的 API key | |||
| url: https://vanna.ai/account/profile | |||
| base_url: | |||
| type: text-input | |||
| required: false | |||
| label: | |||
| en_US: Vanna.AI Endpoint Base URL | |||
| placeholder: | |||
| en_US: https://ask.vanna.ai/rpc | |||