您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

tool_manager.py 41KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024
  1. import json
  2. import logging
  3. import mimetypes
  4. import time
  5. from collections.abc import Generator, Mapping
  6. from os import listdir, path
  7. from threading import Lock
  8. from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
  9. import sqlalchemy as sa
  10. from pydantic import TypeAdapter
  11. from sqlalchemy import select
  12. from sqlalchemy.orm import Session
  13. from yarl import URL
  14. import contexts
  15. from core.helper.provider_cache import ToolProviderCredentialsCache
  16. from core.plugin.entities.plugin import ToolProviderID
  17. from core.plugin.impl.oauth import OAuthHandler
  18. from core.plugin.impl.tool import PluginToolManager
  19. from core.tools.__base.tool_provider import ToolProviderController
  20. from core.tools.__base.tool_runtime import ToolRuntime
  21. from core.tools.mcp_tool.provider import MCPToolProviderController
  22. from core.tools.mcp_tool.tool import MCPTool
  23. from core.tools.plugin_tool.provider import PluginToolProviderController
  24. from core.tools.plugin_tool.tool import PluginTool
  25. from core.tools.utils.uuid_utils import is_valid_uuid
  26. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  27. from core.workflow.entities.variable_pool import VariablePool
  28. from services.tools.mcp_tools_manage_service import MCPToolManageService
  29. if TYPE_CHECKING:
  30. from core.workflow.nodes.tool.entities import ToolEntity
  31. from configs import dify_config
  32. from core.agent.entities import AgentToolEntity
  33. from core.app.entities.app_invoke_entities import InvokeFrom
  34. from core.helper.module_import_helper import load_single_subclass_from_source
  35. from core.helper.position_helper import is_filtered
  36. from core.model_runtime.utils.encoders import jsonable_encoder
  37. from core.tools.__base.tool import Tool
  38. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  39. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  40. from core.tools.builtin_tool.tool import BuiltinTool
  41. from core.tools.custom_tool.provider import ApiToolProviderController
  42. from core.tools.custom_tool.tool import ApiTool
  43. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  44. from core.tools.entities.common_entities import I18nObject
  45. from core.tools.entities.tool_entities import (
  46. ApiProviderAuthType,
  47. CredentialType,
  48. ToolInvokeFrom,
  49. ToolParameter,
  50. ToolProviderType,
  51. )
  52. from core.tools.errors import ToolProviderNotFoundError
  53. from core.tools.tool_label_manager import ToolLabelManager
  54. from core.tools.utils.configuration import (
  55. ToolParameterConfigurationManager,
  56. )
  57. from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
  58. from core.tools.workflow_as_tool.tool import WorkflowTool
  59. from extensions.ext_database import db
  60. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  61. from services.tools.tools_transform_service import ToolTransformService
  # module-level logger; ToolManager uses it to report provider load/lookup failures
  62. logger = logging.getLogger(__name__)
  63. class ToolManager:
  64. _builtin_provider_lock = Lock()
  65. _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
  66. _builtin_providers_loaded = False
  67. _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
  @classmethod
  def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
      """
      Return the hardcoded (in-repository) builtin provider registered under ``provider``.

      The hardcoded provider cache is filled lazily the first time it is found empty.

      :param provider: the provider name used as key in the hardcoded provider cache
      :return: the matching builtin provider controller
      :raises KeyError: if no hardcoded provider with that name was loaded
      """
      if not cls._hardcoded_providers:
          # nothing cached yet: scan and load the builtin providers
          cls.load_hardcoded_providers_cache()
      return cls._hardcoded_providers[provider]
  @classmethod
  def get_builtin_provider(
      cls, provider: str, tenant_id: str
  ) -> BuiltinToolProviderController | PluginToolProviderController:
      """
      Resolve a builtin provider by name, preferring hardcoded providers over plugin providers.

      :param provider: the name of the provider
      :param tenant_id: the id of the tenant, used for the plugin provider lookup
      :return: the hardcoded controller registered under ``provider`` if there is one,
          otherwise the tenant's plugin provider controller
      :raises ToolProviderNotFoundError: propagated from ``get_plugin_provider`` when the
          plugin provider does not exist
      """
      if not cls._hardcoded_providers:
          cls.load_hardcoded_providers_cache()

      if provider in cls._hardcoded_providers:
          return cls._hardcoded_providers[provider]

      plugin_provider = cls.get_plugin_provider(provider, tenant_id)
      if plugin_provider:
          return plugin_provider

      # get_plugin_provider raises instead of returning a falsy value, so this fallback is
      # not expected to be reached; it mirrors the original lookup behaviour
      return cls._hardcoded_providers[provider]
  @classmethod
  def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
      """
      Return the plugin tool provider controller for ``provider``, memoized per context.

      Controllers are cached in ``contexts.plugin_tool_providers``. The cache dict and the lock
      guarding it are created the first time they are missing in the current context. The cache
      is consulted again after acquiring the lock, so callers sharing the context do not fetch
      and build the same provider twice.

      :param provider: the plugin provider id
      :param tenant_id: the id of the tenant
      :return: the plugin provider controller
      :raises ToolProviderNotFoundError: if the plugin manager returns no provider entity
      """
      try:
          contexts.plugin_tool_providers.get()
      except LookupError:
          # first use in this context: create the cache and its lock
          contexts.plugin_tool_providers.set({})
          contexts.plugin_tool_providers_lock.set(Lock())

      cached_providers = contexts.plugin_tool_providers.get()
      if provider in cached_providers:
          return cached_providers[provider]

      with contexts.plugin_tool_providers_lock.get():
          # re-check under the lock: another caller may have filled the entry meanwhile
          cached_providers = contexts.plugin_tool_providers.get()
          if provider in cached_providers:
              return cached_providers[provider]

          entity = PluginToolManager().fetch_tool_provider(tenant_id, provider)
          if not entity:
              raise ToolProviderNotFoundError(f"plugin provider {provider} not found")

          controller = PluginToolProviderController(
              entity=entity.declaration,
              plugin_id=entity.plugin_id,
              plugin_unique_identifier=entity.plugin_unique_identifier,
              tenant_id=tenant_id,
          )
          cached_providers[provider] = controller
          return controller
  @classmethod
  def get_tool_runtime(
      cls,
      provider_type: ToolProviderType,
      provider_id: str,
      tool_name: str,
      tenant_id: str,
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
      credential_id: Optional[str] = None,
  ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
      """
      get the tool runtime

      Resolves the provider identified by ``provider_type``/``provider_id``, looks up ``tool_name``
      on it and (for builtin, API and workflow providers) returns the tool forked with a
      ``ToolRuntime`` bound to ``tenant_id``. Builtin providers that need credentials get their
      stored credentials decrypted; if those credentials are about to expire, they are refreshed
      through the OAuth handler and persisted before use.

      :param provider_type: the type of the provider
      :param provider_id: the id of the provider
      :param tool_name: the name of the tool
      :param tenant_id: the tenant id
      :param invoke_from: invoke from
      :param tool_invoke_from: the tool invoke from
      :param credential_id: the credential id; only consulted for plugin-backed builtin providers
      :return: the tool
      :raises ToolProviderNotFoundError: if the provider, the tool or the required credentials
          cannot be found
      :raises NotImplementedError: for ``ToolProviderType.APP``
      """
      if provider_type == ToolProviderType.BUILT_IN:
          # check if the builtin tool need credentials
          provider_controller = cls.get_builtin_provider(provider_id, tenant_id)

          builtin_tool = provider_controller.get_tool(tool_name)
          if not builtin_tool:
              raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

          if not provider_controller.need_credentials:
              return cast(
                  BuiltinTool,
                  builtin_tool.fork_tool_runtime(
                      runtime=ToolRuntime(
                          tenant_id=tenant_id,
                          credentials={},
                          invoke_from=invoke_from,
                          tool_invoke_from=tool_invoke_from,
                      )
                  ),
              )

          builtin_provider = None
          if isinstance(provider_controller, PluginToolProviderController):
              provider_id_entity = ToolProviderID(provider_id)
              # get specific credentials
              if is_valid_uuid(credential_id):
                  try:
                      builtin_provider_stmt = select(BuiltinToolProvider).where(
                          BuiltinToolProvider.tenant_id == tenant_id,
                          BuiltinToolProvider.id == credential_id,
                      )
                      builtin_provider = db.session.scalar(builtin_provider_stmt)
                  except Exception as e:
                      builtin_provider = None
                      logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
                  # if the provider has been deleted, raise an error
                  if builtin_provider is None:
                      raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")

              # fallback to the default provider: is_default first, then the oldest one
              if builtin_provider is None:
                  builtin_provider = (
                      db.session.query(BuiltinToolProvider)
                      .where(
                          BuiltinToolProvider.tenant_id == tenant_id,
                          (BuiltinToolProvider.provider == str(provider_id_entity))
                          | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                      )
                      .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                      .first()
                  )
                  if builtin_provider is None:
                      raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
          else:
              builtin_provider = (
                  db.session.query(BuiltinToolProvider)
                  .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                  .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                  .first()
              )
              if builtin_provider is None:
                  raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

          encrypter, cache = create_provider_encrypter(
              tenant_id=tenant_id,
              config=[
                  x.to_basic_provider_config()
                  for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
              ],
              cache=ToolProviderCredentialsCache(
                  tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
              ),
          )

          # decrypt the credentials
          decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

          # refresh when the credentials expire within the next 60 seconds; expires_at == -1 skips this
          if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
              # TODO: circular import
              from services.tools.builtin_tools_manage_service import BuiltinToolManageService

              # refresh the credentials
              tool_provider = ToolProviderID(provider_id)
              provider_name = tool_provider.provider_name
              redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
              system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
              oauth_handler = OAuthHandler()
              refreshed_credentials = oauth_handler.refresh_credentials(
                  tenant_id=tenant_id,
                  user_id=builtin_provider.user_id,
                  plugin_id=tool_provider.plugin_id,
                  provider=provider_name,
                  redirect_uri=redirect_uri,
                  system_credentials=system_credentials or {},
                  credentials=decrypted_credentials,
              )
              # persist the refreshed credentials (encrypted) and the new expiry
              builtin_provider.encrypted_credentials = (
                  TypeAdapter(dict[str, Any])
                  .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
                  .decode("utf-8")
              )
              builtin_provider.expires_at = refreshed_credentials.expires_at
              db.session.commit()
              decrypted_credentials = refreshed_credentials.credentials
              # invalidate the credentials cache since the stored credentials changed
              cache.delete()

          return cast(
              BuiltinTool,
              builtin_tool.fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials=dict(decrypted_credentials),
                      credential_type=CredentialType.of(builtin_provider.credential_type),
                      runtime_parameters={},
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )

      elif provider_type == ToolProviderType.API:
          api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
          encrypter, _ = create_tool_provider_encrypter(
              tenant_id=tenant_id,
              controller=api_provider,
          )
          return cast(
              ApiTool,
              api_provider.get_tool(tool_name).fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials=encrypter.decrypt(credentials),
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )

      elif provider_type == ToolProviderType.WORKFLOW:
          workflow_provider_stmt = select(WorkflowToolProvider).where(
              WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
          )
          workflow_provider = db.session.scalar(workflow_provider_stmt)

          if workflow_provider is None:
              raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

          controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
          # fetch the tools once and reuse the validated list instead of calling get_tools again
          controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
          if not controller_tools:
              raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

          return controller_tools[0].fork_tool_runtime(
              runtime=ToolRuntime(
                  tenant_id=tenant_id,
                  credentials={},
                  invoke_from=invoke_from,
                  tool_invoke_from=tool_invoke_from,
              )
          )

      elif provider_type == ToolProviderType.APP:
          raise NotImplementedError("app provider not implemented")

      elif provider_type == ToolProviderType.PLUGIN:
          return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)

      elif provider_type == ToolProviderType.MCP:
          return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)

      else:
          raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  @classmethod
  def get_agent_tool_runtime(
      cls,
      tenant_id: str,
      app_id: str,
      agent_tool: AgentToolEntity,
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      variable_pool: Optional[VariablePool] = None,
  ) -> Tool:
      """
      Build the runtime of a tool configured on an agent app.

      The tool is resolved through ``get_tool_runtime`` with ``ToolInvokeFrom.AGENT``. The agent's
      configured tool parameters are converted into runtime parameters, decrypted, and merged
      into the tool's runtime.

      :param tenant_id: the id of the tenant
      :param app_id: the id of the agent app (part of the parameter-encryption identity)
      :param agent_tool: the agent tool configuration
      :param invoke_from: where the invocation originates
      :param variable_pool: optional variable pool handed to the parameter conversion
      :return: the tool with its runtime parameters populated
      :raises ValueError: if the resolved tool has no runtime or no runtime parameters
      """
      tool_entity = cls.get_tool_runtime(
          provider_type=agent_tool.provider_type,
          provider_id=agent_tool.provider_id,
          tool_name=agent_tool.tool_name,
          tenant_id=tenant_id,
          invoke_from=invoke_from,
          tool_invoke_from=ToolInvokeFrom.AGENT,
          credential_id=agent_tool.credential_id,
      )

      parameters = tool_entity.get_merged_runtime_parameters()
      runtime_parameters = cls._convert_tool_parameters_type(
          parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
      )

      # decrypt runtime parameters
      encryption_manager = ToolParameterConfigurationManager(
          tenant_id=tenant_id,
          tool_runtime=tool_entity,
          provider_name=agent_tool.provider_id,
          provider_type=agent_tool.provider_type,
          identity_id=f"AGENT.{app_id}",
      )
      runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

      if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")

      tool_entity.runtime.runtime_parameters.update(runtime_parameters)
      return tool_entity
  @classmethod
  def get_workflow_tool_runtime(
      cls,
      tenant_id: str,
      app_id: str,
      node_id: str,
      workflow_tool: "ToolEntity",
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      variable_pool: Optional[VariablePool] = None,
  ) -> Tool:
      """
      Build the runtime of a tool used by a workflow tool node.

      The tool is resolved through ``get_tool_runtime`` with ``ToolInvokeFrom.WORKFLOW``. The node's
      tool configurations are converted into runtime parameters, decrypted when non-empty, and
      merged into the tool's runtime.

      :param tenant_id: the id of the tenant
      :param app_id: the id of the workflow app (part of the parameter-encryption identity)
      :param node_id: the id of the tool node (part of the parameter-encryption identity)
      :param workflow_tool: the node's tool entity
      :param invoke_from: where the invocation originates
      :param variable_pool: optional variable pool handed to the parameter conversion
      :return: the tool with its runtime parameters populated
      :raises ValueError: if the resolved tool has no runtime or no runtime parameters
      """
      tool_runtime = cls.get_tool_runtime(
          provider_type=workflow_tool.provider_type,
          provider_id=workflow_tool.provider_id,
          tool_name=workflow_tool.tool_name,
          tenant_id=tenant_id,
          invoke_from=invoke_from,
          tool_invoke_from=ToolInvokeFrom.WORKFLOW,
          credential_id=workflow_tool.credential_id,
      )

      parameters = tool_runtime.get_merged_runtime_parameters()
      runtime_parameters = cls._convert_tool_parameters_type(
          parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
      )

      # decrypt runtime parameters
      encryption_manager = ToolParameterConfigurationManager(
          tenant_id=tenant_id,
          tool_runtime=tool_runtime,
          provider_name=workflow_tool.provider_id,
          provider_type=workflow_tool.provider_type,
          identity_id=f"WORKFLOW.{app_id}.{node_id}",
      )

      if runtime_parameters:
          runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

      # fail with a clear error instead of an AttributeError when the runtime is missing
      if tool_runtime.runtime is None or tool_runtime.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")

      tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
      return tool_runtime
  @classmethod
  def get_tool_runtime_from_plugin(
      cls,
      tool_type: ToolProviderType,
      tenant_id: str,
      provider: str,
      tool_name: str,
      tool_parameters: dict[str, Any],
      credential_id: Optional[str] = None,
  ) -> Tool:
      """
      Build a tool runtime on behalf of a plugin invocation.

      The tool is resolved through ``get_tool_runtime`` with ``InvokeFrom.SERVICE_API`` and
      ``ToolInvokeFrom.PLUGIN``. Every merged runtime parameter whose form is ``FORM`` is
      initialised from ``tool_parameters`` (via ``init_frontend_parameter``) and stored on the
      tool's runtime.

      :param tool_type: the provider type of the tool
      :param tenant_id: the id of the tenant
      :param provider: the provider id
      :param tool_name: the name of the tool
      :param tool_parameters: raw parameter values keyed by parameter name
      :param credential_id: optional credential id forwarded to ``get_tool_runtime``
      :return: the tool with its form parameters populated
      :raises ValueError: if the resolved tool has no runtime or no runtime parameters
      """
      tool_entity = cls.get_tool_runtime(
          provider_type=tool_type,
          provider_id=provider,
          tool_name=tool_name,
          tenant_id=tenant_id,
          invoke_from=InvokeFrom.SERVICE_API,
          tool_invoke_from=ToolInvokeFrom.PLUGIN,
          credential_id=credential_id,
      )

      # save form parameters to the tool entity's runtime memory
      runtime_parameters = {
          parameter.name: parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
          for parameter in tool_entity.get_merged_runtime_parameters()
          if parameter.form == ToolParameter.ToolParameterForm.FORM
      }

      # fail with a clear error instead of an AttributeError when the runtime is missing
      if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")

      tool_entity.runtime.runtime_parameters.update(runtime_parameters)
      return tool_entity
  @classmethod
  def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
      """
      Locate the icon file shipped with a hardcoded builtin provider.

      The icon is looked up at ``builtin_tool/providers/<provider>/_assets/<icon>`` relative to
      the directory of this module, where ``<icon>`` is the provider identity's icon name.

      :param provider: the name of the provider
      :return: a tuple ``(absolute icon path, mime type)``; the mime type falls back to
          ``application/octet-stream`` when it cannot be guessed from the file name
      :raises ToolProviderNotFoundError: if the icon file does not exist
      """
      icon_name = cls.get_hardcoded_provider(provider).entity.identity.icon
      module_dir = path.dirname(path.realpath(__file__))
      icon_path = path.join(module_dir, "builtin_tool", "providers", provider, "_assets", icon_name)

      if not path.exists(icon_path):
          raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")

      guessed_type, _encoding = mimetypes.guess_type(icon_path)
      return icon_path, guessed_type or "application/octet-stream"
  @classmethod
  def list_hardcoded_providers(cls):
      """
      Yield every hardcoded builtin provider controller, loading them on first use.

      Loading runs to completion under ``_builtin_provider_lock``, double-checked on
      ``_builtin_providers_loaded``. The providers are yielded only afterwards, from a snapshot,
      so consumer code never runs while the non-reentrant lock is held. Otherwise a consumer
      that stops iterating early would keep the lock held and leave ``_hardcoded_providers``
      partially filled, and ``get_hardcoded_provider`` treats a non-empty cache as loaded. A
      consumer that calls back into this method during the first load could also deadlock.

      :return: a generator of ``BuiltinToolProviderController`` instances
      """
      if not cls._builtin_providers_loaded:
          with cls._builtin_provider_lock:
              if not cls._builtin_providers_loaded:
                  # drain the loader: it fills the caches and sets _builtin_providers_loaded
                  for _ in cls._list_hardcoded_providers():
                      pass
      yield from list(cls._hardcoded_providers.values())
  @classmethod
  def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
      """
      Build a controller for every plugin tool provider returned for the tenant.

      :param tenant_id: the id of the tenant
      :return: one ``PluginToolProviderController`` per provider entity, in fetch order
      """
      entities = PluginToolManager().fetch_tool_providers(tenant_id)
      return [
          PluginToolProviderController(
              entity=entity.declaration,
              plugin_id=entity.plugin_id,
              plugin_unique_identifier=entity.plugin_unique_identifier,
              tenant_id=tenant_id,
          )
          for entity in entities
      ]
  @classmethod
  def list_builtin_providers(
      cls, tenant_id: str
  ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
      """
      Lazily yield every builtin provider visible to a tenant.

      Hardcoded providers are produced first. The tenant's plugin providers are fetched and
      yielded only once the hardcoded ones have been consumed.

      :param tenant_id: the id of the tenant
      """
      hardcoded_providers = cls.list_hardcoded_providers()
      yield from hardcoded_providers

      plugin_providers = cls.list_plugin_providers(tenant_id)
      yield from plugin_providers
  @classmethod
  def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
      """
      Discover, instantiate and register the builtin providers under ``builtin_tool/providers``.

      Every sub-directory ``<name>`` whose name does not start with ``__`` is loaded from
      ``<name>/<name>.py``, which must define a single ``BuiltinToolProviderController`` subclass.
      For each successfully loaded provider:

      - the provider is stored in ``_hardcoded_providers`` under its identity name;
      - its tools' labels are recorded in ``_builtin_tools_labels``;
      - the provider is yielded.

      Providers that fail to load are logged and skipped. ``_builtin_providers_loaded`` is set
      only after the whole directory has been scanned.
      """
      providers_dir = path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")
      for provider_path in listdir(providers_dir):
          # skip dunder entries such as __pycache__ / __init__.py
          if provider_path.startswith("__"):
              continue

          provider_dir = path.join(providers_dir, provider_path)
          if not path.isdir(provider_dir):
              continue

          # init provider
          try:
              provider_class = load_single_subclass_from_source(
                  module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
                  script_path=path.join(provider_dir, f"{provider_path}.py"),
                  parent_type=BuiltinToolProviderController,
              )
              provider: BuiltinToolProviderController = provider_class()
              cls._hardcoded_providers[provider.entity.identity.name] = provider
              for tool in provider.get_tools():
                  cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
              yield provider
          except Exception:
              logger.exception("load builtin provider %s", provider_path)

      # set builtin providers loaded
      cls._builtin_providers_loaded = True
  @classmethod
  def load_hardcoded_providers_cache(cls):
      """
      Populate the hardcoded provider caches by exhausting ``list_hardcoded_providers``.

      The yielded providers are discarded; only the caching side effect is wanted.
      """
      providers = cls.list_hardcoded_providers()
      for _provider in providers:
          continue
  @classmethod
  def clear_hardcoded_providers_cache(cls):
      """
      Drop the cached hardcoded providers so the next access reloads them.

      The tool-label cache is reset too. It is filled by the same loader, and ``get_tool_label``
      only triggers a reload when that cache is empty, so leaving it populated would keep
      serving labels from the previous load.
      """
      cls._hardcoded_providers = {}
      cls._builtin_tools_labels = {}
      cls._builtin_providers_loaded = False
  @classmethod
  def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
      """
      Look up the display label of a hardcoded builtin tool.

      :param tool_name: the name of the tool
      :return: the tool's label, or None when no loaded builtin tool has that name
      """
      if not cls._builtin_tools_labels:
          # labels are collected while the hardcoded providers are being loaded
          cls.load_hardcoded_providers_cache()
      return cls._builtin_tools_labels.get(tool_name)
  @classmethod
  def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
      """
      List the default credential record of every builtin provider configured by the tenant.

      A provider may have several credential records. For each provider, the record with
      ``is_default`` set wins; otherwise the oldest record (smallest ``created_at``) is used, for
      compatibility with the single-credential era. This matches the fallback ordering used by
      ``get_tool_runtime`` (``is_default DESC, created_at ASC``), so the credential shown here is
      the one actually used at runtime.

      Note: ``DISTINCT ON`` is PostgreSQL-specific.

      :param tenant_id: the id of the tenant
      :return: one ``BuiltinToolProvider`` row per provider (order not guaranteed)
      """
      sql = """
          SELECT DISTINCT ON (tenant_id, provider) id
          FROM tool_builtin_providers
          WHERE tenant_id = :tenant_id
          ORDER BY tenant_id, provider, is_default DESC, created_at ASC
      """
      with Session(db.engine, autoflush=False) as session:
          ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
          return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
  @classmethod
  def list_providers_from_api(
      cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  ) -> list[ToolProviderApiEntity]:
      """
      List the tenant's tool providers as user-facing API entities, sorted for display.

      :param user_id: the id of the requesting user (not used by this method)
      :param tenant_id: the id of the tenant
      :param typ: restrict the listing to a single provider kind; a falsy value lists builtin,
          api, workflow and mcp providers
      :return: the provider entities, sorted by ``BuiltinToolProviderSort``
      """
      result_providers: dict[str, ToolProviderApiEntity] = {}

      filters = []
      if not typ:
          filters.extend(["builtin", "api", "workflow", "mcp"])
      else:
          filters.append(typ)

      with db.session.no_autoflush:
          if "builtin" in filters:
              builtin_providers = cls.list_builtin_providers(tenant_id)

              # key: str(ToolProviderID(...)) of the stored provider, value: its default credential record
              db_builtin_providers = {
                  str(ToolProviderID(provider.provider)): provider
                  for provider in cls.list_default_builtin_providers(tenant_id)
              }

              # append builtin providers
              for provider in builtin_providers:
                  # handle include, exclude
                  if is_filtered(
                      include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
                      exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
                      data=provider,
                      name_func=lambda x: x.identity.name,
                  ):
                      continue

                  user_provider = ToolTransformService.builtin_provider_to_user_provider(
                      provider_controller=provider,
                      db_provider=db_builtin_providers.get(provider.entity.identity.name),
                      decrypt_credentials=False,
                  )

                  if isinstance(provider, PluginToolProviderController):
                      result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
                  else:
                      result_providers[f"builtin_provider.{user_provider.name}"] = user_provider

          # get db api providers
          if "api" in filters:
              db_api_providers: list[ApiToolProvider] = (
                  db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
              )

              api_provider_controllers: list[dict[str, Any]] = [
                  {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
                  for provider in db_api_providers
              ]

              # get labels
              labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])

              for api_provider_controller in api_provider_controllers:
                  user_provider = ToolTransformService.api_provider_to_user_provider(
                      provider_controller=api_provider_controller["controller"],
                      db_provider=api_provider_controller["provider"],
                      decrypt_credentials=False,
                      labels=labels.get(api_provider_controller["controller"].provider_id, []),
                  )
                  result_providers[f"api_provider.{user_provider.name}"] = user_provider

          if "workflow" in filters:
              # get workflow providers
              workflow_providers: list[WorkflowToolProvider] = (
                  db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
              )

              workflow_provider_controllers: list[WorkflowToolProviderController] = []
              for workflow_provider in workflow_providers:
                  try:
                      workflow_provider_controllers.append(
                          ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
                      )
                  except Exception:
                      # typically the backing app has been deleted; skip the provider but leave a trace
                      logger.warning(
                          "failed to load workflow tool provider %s", workflow_provider.id, exc_info=True
                      )

              labels = ToolLabelManager.get_tools_labels(
                  [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
              )

              for provider_controller in workflow_provider_controllers:
                  user_provider = ToolTransformService.workflow_provider_to_user_provider(
                      provider_controller=provider_controller,
                      labels=labels.get(provider_controller.provider_id, []),
                  )
                  result_providers[f"workflow_provider.{user_provider.name}"] = user_provider

          if "mcp" in filters:
              mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
              for mcp_provider in mcp_providers:
                  result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider

      return BuiltinToolProviderSort.sort(list(result_providers.values()))
  632. @classmethod
  633. def get_api_provider_controller(
  634. cls, tenant_id: str, provider_id: str
  635. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  636. """
  637. get the api provider
  638. :param tenant_id: the id of the tenant
  639. :param provider_id: the id of the provider
  640. :return: the provider controller, the credentials
  641. """
  642. provider: ApiToolProvider | None = (
  643. db.session.query(ApiToolProvider)
  644. .where(
  645. ApiToolProvider.id == provider_id,
  646. ApiToolProvider.tenant_id == tenant_id,
  647. )
  648. .first()
  649. )
  650. if provider is None:
  651. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  652. auth_type = ApiProviderAuthType.NONE
  653. provider_auth_type = provider.credentials.get("auth_type")
  654. if provider_auth_type in ("api_key_header", "api_key"): # backward compatibility
  655. auth_type = ApiProviderAuthType.API_KEY_HEADER
  656. elif provider_auth_type == "api_key_query":
  657. auth_type = ApiProviderAuthType.API_KEY_QUERY
  658. controller = ApiToolProviderController.from_db(
  659. provider,
  660. auth_type,
  661. )
  662. controller.load_bundled_tools(provider.tools)
  663. return controller, provider.credentials
  @classmethod
  def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
      """
      Load an MCP tool provider of a tenant and build its controller.

      :param tenant_id: the id of the tenant
      :param provider_id: the MCP server identifier, matched against ``MCPToolProvider.server_identifier``
      :return: the MCP provider controller
      :raises ToolProviderNotFoundError: if the tenant has no MCP provider with that server identifier
      """
      conditions = (
          MCPToolProvider.server_identifier == provider_id,
          MCPToolProvider.tenant_id == tenant_id,
      )
      db_provider: MCPToolProvider | None = db.session.query(MCPToolProvider).where(*conditions).first()
      if db_provider is None:
          raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
      return MCPToolProviderController._from_db(db_provider)
  684. @classmethod
  685. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  686. """
  687. get api provider
  688. """
  689. provider_name = provider
  690. provider_obj: ApiToolProvider | None = (
  691. db.session.query(ApiToolProvider)
  692. .where(
  693. ApiToolProvider.tenant_id == tenant_id,
  694. ApiToolProvider.name == provider,
  695. )
  696. .first()
  697. )
  698. if provider_obj is None:
  699. raise ValueError(f"you have not added provider {provider_name}")
  700. try:
  701. credentials = json.loads(provider_obj.credentials_str) or {}
  702. except Exception:
  703. credentials = {}
  704. # package tool provider controller
  705. auth_type = ApiProviderAuthType.NONE
  706. credentials_auth_type = credentials.get("auth_type")
  707. if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
  708. auth_type = ApiProviderAuthType.API_KEY_HEADER
  709. elif credentials_auth_type == "api_key_query":
  710. auth_type = ApiProviderAuthType.API_KEY_QUERY
  711. controller = ApiToolProviderController.from_db(
  712. provider_obj,
  713. auth_type,
  714. )
  715. # init tool configuration
  716. encrypter, _ = create_tool_provider_encrypter(
  717. tenant_id=tenant_id,
  718. controller=controller,
  719. )
  720. masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
  721. try:
  722. icon = json.loads(provider_obj.icon)
  723. except Exception:
  724. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  725. # add tool labels
  726. labels = ToolLabelManager.get_tool_labels(controller)
  727. return cast(
  728. dict,
  729. jsonable_encoder(
  730. {
  731. "schema_type": provider_obj.schema_type,
  732. "schema": provider_obj.schema,
  733. "tools": provider_obj.tools,
  734. "icon": icon,
  735. "description": provider_obj.description,
  736. "credentials": masked_credentials,
  737. "privacy_policy": provider_obj.privacy_policy,
  738. "custom_disclaimer": provider_obj.custom_disclaimer,
  739. "labels": labels,
  740. }
  741. ),
  742. )
  743. @classmethod
  744. def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
  745. return str(
  746. URL(dify_config.CONSOLE_API_URL or "/")
  747. / "console"
  748. / "api"
  749. / "workspaces"
  750. / "current"
  751. / "tool-provider"
  752. / "builtin"
  753. / provider_id
  754. / "icon"
  755. )
  756. @classmethod
  757. def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
  758. return str(
  759. URL(dify_config.CONSOLE_API_URL or "/")
  760. / "console"
  761. / "api"
  762. / "workspaces"
  763. / "current"
  764. / "plugin"
  765. / "icon"
  766. % {"tenant_id": tenant_id, "filename": filename}
  767. )
  768. @classmethod
  769. def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  770. try:
  771. workflow_provider: WorkflowToolProvider | None = (
  772. db.session.query(WorkflowToolProvider)
  773. .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  774. .first()
  775. )
  776. if workflow_provider is None:
  777. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  778. icon: dict = json.loads(workflow_provider.icon)
  779. return icon
  780. except Exception:
  781. return {"background": "#252525", "content": "\ud83d\ude01"}
  782. @classmethod
  783. def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  784. try:
  785. api_provider: ApiToolProvider | None = (
  786. db.session.query(ApiToolProvider)
  787. .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  788. .first()
  789. )
  790. if api_provider is None:
  791. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  792. icon: dict = json.loads(api_provider.icon)
  793. return icon
  794. except Exception:
  795. return {"background": "#252525", "content": "\ud83d\ude01"}
  796. @classmethod
  797. def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
  798. try:
  799. mcp_provider: MCPToolProvider | None = (
  800. db.session.query(MCPToolProvider)
  801. .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
  802. .first()
  803. )
  804. if mcp_provider is None:
  805. raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
  806. return mcp_provider.provider_icon
  807. except Exception:
  808. return {"background": "#252525", "content": "\ud83d\ude01"}
  809. @classmethod
  810. def get_tool_icon(
  811. cls,
  812. tenant_id: str,
  813. provider_type: ToolProviderType,
  814. provider_id: str,
  815. ) -> Union[str, dict]:
  816. """
  817. get the tool icon
  818. :param tenant_id: the id of the tenant
  819. :param provider_type: the type of the provider
  820. :param provider_id: the id of the provider
  821. :return:
  822. """
  823. provider_type = provider_type
  824. provider_id = provider_id
  825. if provider_type == ToolProviderType.BUILT_IN:
  826. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  827. if isinstance(provider, PluginToolProviderController):
  828. try:
  829. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  830. except Exception:
  831. return {"background": "#252525", "content": "\ud83d\ude01"}
  832. return cls.generate_builtin_tool_icon_url(provider_id)
  833. elif provider_type == ToolProviderType.API:
  834. return cls.generate_api_tool_icon_url(tenant_id, provider_id)
  835. elif provider_type == ToolProviderType.WORKFLOW:
  836. return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
  837. elif provider_type == ToolProviderType.PLUGIN:
  838. provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
  839. if isinstance(provider, PluginToolProviderController):
  840. try:
  841. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  842. except Exception:
  843. return {"background": "#252525", "content": "\ud83d\ude01"}
  844. raise ValueError(f"plugin provider {provider_id} not found")
  845. elif provider_type == ToolProviderType.MCP:
  846. return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
  847. else:
  848. raise ValueError(f"provider type {provider_type} not found")
  849. @classmethod
  850. def _convert_tool_parameters_type(
  851. cls,
  852. parameters: list[ToolParameter],
  853. variable_pool: Optional[VariablePool],
  854. tool_configurations: dict[str, Any],
  855. typ: Literal["agent", "workflow", "tool"] = "workflow",
  856. ) -> dict[str, Any]:
  857. """
  858. Convert tool parameters type
  859. """
  860. from core.workflow.nodes.tool.entities import ToolNodeData
  861. from core.workflow.nodes.tool.exc import ToolParameterError
  862. runtime_parameters = {}
  863. for parameter in parameters:
  864. if (
  865. parameter.type
  866. in {
  867. ToolParameter.ToolParameterType.SYSTEM_FILES,
  868. ToolParameter.ToolParameterType.FILE,
  869. ToolParameter.ToolParameterType.FILES,
  870. }
  871. and parameter.required
  872. and typ == "agent"
  873. ):
  874. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  875. # save tool parameter to tool entity memory
  876. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  877. if variable_pool:
  878. config = tool_configurations.get(parameter.name, {})
  879. if not (config and isinstance(config, dict) and config.get("value") is not None):
  880. continue
  881. tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
  882. if tool_input.type == "variable":
  883. variable = variable_pool.get(tool_input.value)
  884. if variable is None:
  885. raise ToolParameterError(f"Variable {tool_input.value} does not exist")
  886. parameter_value = variable.value
  887. elif tool_input.type == "constant":
  888. parameter_value = tool_input.value
  889. elif tool_input.type == "mixed":
  890. segment_group = variable_pool.convert_template(str(tool_input.value))
  891. parameter_value = segment_group.text
  892. else:
  893. raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
  894. runtime_parameters[parameter.name] = parameter_value
  895. else:
  896. value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
  897. runtime_parameters[parameter.name] = value
  898. return runtime_parameters
# Module-level statement: runs once, when this module is imported.
# NOTE(review): presumably pre-populates ToolManager's cache of hard-coded
# (built-in) providers so later lookups do not load them on demand — confirm.
ToolManager.load_hardcoded_providers_cache()