|
|
|
@@ -1,4 +1,6 @@ |
|
|
|
import re |
|
|
|
from typing import Any |
|
|
|
from urllib.parse import urlparse |
|
|
|
|
|
|
|
from core.tools.errors import ToolProviderCredentialValidationError |
|
|
|
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool |
|
|
|
@@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl |
|
|
|
|
|
|
|
|
|
|
|
class VannaProvider(BuiltinToolProviderController): |
|
|
|
def _get_protocol_and_main_domain(self, url): |
|
|
|
parsed_url = urlparse(url) |
|
|
|
protocol = parsed_url.scheme |
|
|
|
hostname = parsed_url.hostname |
|
|
|
port = f":{parsed_url.port}" if parsed_url.port else "" |
|
|
|
|
|
|
|
# Check if the hostname is an IP address |
|
|
|
is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None |
|
|
|
|
|
|
|
# Return the full hostname (with port if present) for IP addresses, otherwise return the main domain |
|
|
|
main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port |
|
|
|
return f"{protocol}://{main_domain}" |
|
|
|
|
|
|
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None: |
|
|
|
base_url = credentials.get("base_url") |
|
|
|
if not base_url: |
|
|
|
base_url = "https://ask.vanna.ai/rpc" |
|
|
|
else: |
|
|
|
base_url = base_url.removesuffix("/") |
|
|
|
credentials["base_url"] = base_url |
|
|
|
try: |
|
|
|
VannaTool().fork_tool_runtime( |
|
|
|
runtime={ |
|
|
|
@@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController): |
|
|
|
tool_parameters={ |
|
|
|
"model": "chinook", |
|
|
|
"db_type": "SQLite", |
|
|
|
"url": "https://vanna.ai/Chinook.sqlite", |
|
|
|
"url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite', |
|
|
|
"query": "What are the top 10 customers by sales?", |
|
|
|
}, |
|
|
|
) |