You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tool_manager.py 41KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027
  1. import json
  2. import logging
  3. import mimetypes
  4. import time
  5. from collections.abc import Generator, Mapping
  6. from os import listdir, path
  7. from threading import Lock
  8. from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
  9. from pydantic import TypeAdapter
  10. from yarl import URL
  11. import contexts
  12. from core.helper.provider_cache import ToolProviderCredentialsCache
  13. from core.plugin.entities.plugin import ToolProviderID
  14. from core.plugin.impl.oauth import OAuthHandler
  15. from core.plugin.impl.tool import PluginToolManager
  16. from core.tools.__base.tool_provider import ToolProviderController
  17. from core.tools.__base.tool_runtime import ToolRuntime
  18. from core.tools.mcp_tool.provider import MCPToolProviderController
  19. from core.tools.mcp_tool.tool import MCPTool
  20. from core.tools.plugin_tool.provider import PluginToolProviderController
  21. from core.tools.plugin_tool.tool import PluginTool
  22. from core.tools.utils.uuid_utils import is_valid_uuid
  23. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  24. from core.workflow.entities.variable_pool import VariablePool
  25. from services.tools.mcp_tools_manage_service import MCPToolManageService
  26. if TYPE_CHECKING:
  27. from core.workflow.nodes.tool.entities import ToolEntity
  28. from configs import dify_config
  29. from core.agent.entities import AgentToolEntity
  30. from core.app.entities.app_invoke_entities import InvokeFrom
  31. from core.helper.module_import_helper import load_single_subclass_from_source
  32. from core.helper.position_helper import is_filtered
  33. from core.model_runtime.utils.encoders import jsonable_encoder
  34. from core.tools.__base.tool import Tool
  35. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  36. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  37. from core.tools.builtin_tool.tool import BuiltinTool
  38. from core.tools.custom_tool.provider import ApiToolProviderController
  39. from core.tools.custom_tool.tool import ApiTool
  40. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  41. from core.tools.entities.common_entities import I18nObject
  42. from core.tools.entities.tool_entities import (
  43. ApiProviderAuthType,
  44. CredentialType,
  45. ToolInvokeFrom,
  46. ToolParameter,
  47. ToolProviderType,
  48. )
  49. from core.tools.errors import ToolProviderNotFoundError
  50. from core.tools.tool_label_manager import ToolLabelManager
  51. from core.tools.utils.configuration import (
  52. ToolParameterConfigurationManager,
  53. )
  54. from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
  55. from core.tools.workflow_as_tool.tool import WorkflowTool
  56. from extensions.ext_database import db
  57. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  58. from services.tools.tools_transform_service import ToolTransformService
  59. logger = logging.getLogger(__name__)
  60. class ToolManager:
  61. _builtin_provider_lock = Lock()
  62. _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
  63. _builtin_providers_loaded = False
  64. _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
  65. @classmethod
  66. def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
  67. """
  68. get the hardcoded provider
  69. """
  70. if len(cls._hardcoded_providers) == 0:
  71. # init the builtin providers
  72. cls.load_hardcoded_providers_cache()
  73. return cls._hardcoded_providers[provider]
  74. @classmethod
  75. def get_builtin_provider(
  76. cls, provider: str, tenant_id: str
  77. ) -> BuiltinToolProviderController | PluginToolProviderController:
  78. """
  79. get the builtin provider
  80. :param provider: the name of the provider
  81. :param tenant_id: the id of the tenant
  82. :return: the provider
  83. """
  84. # split provider to
  85. if len(cls._hardcoded_providers) == 0:
  86. # init the builtin providers
  87. cls.load_hardcoded_providers_cache()
  88. if provider not in cls._hardcoded_providers:
  89. # get plugin provider
  90. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  91. if plugin_provider:
  92. return plugin_provider
  93. return cls._hardcoded_providers[provider]
  94. @classmethod
  95. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  96. """
  97. get the plugin provider
  98. """
  99. # check if context is set
  100. try:
  101. contexts.plugin_tool_providers.get()
  102. except LookupError:
  103. contexts.plugin_tool_providers.set({})
  104. contexts.plugin_tool_providers_lock.set(Lock())
  105. plugin_tool_providers = contexts.plugin_tool_providers.get()
  106. if provider in plugin_tool_providers:
  107. return plugin_tool_providers[provider]
  108. with contexts.plugin_tool_providers_lock.get():
  109. # double check
  110. plugin_tool_providers = contexts.plugin_tool_providers.get()
  111. if provider in plugin_tool_providers:
  112. return plugin_tool_providers[provider]
  113. manager = PluginToolManager()
  114. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  115. if not provider_entity:
  116. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  117. controller = PluginToolProviderController(
  118. entity=provider_entity.declaration,
  119. plugin_id=provider_entity.plugin_id,
  120. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  121. tenant_id=tenant_id,
  122. )
  123. plugin_tool_providers[provider] = controller
  124. return controller
  125. @classmethod
  126. def get_tool_runtime(
  127. cls,
  128. provider_type: ToolProviderType,
  129. provider_id: str,
  130. tool_name: str,
  131. tenant_id: str,
  132. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  133. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  134. credential_id: Optional[str] = None,
  135. ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
  136. """
  137. get the tool runtime
  138. :param provider_type: the type of the provider
  139. :param provider_id: the id of the provider
  140. :param tool_name: the name of the tool
  141. :param tenant_id: the tenant id
  142. :param invoke_from: invoke from
  143. :param tool_invoke_from: the tool invoke from
  144. :param credential_id: the credential id
  145. :return: the tool
  146. """
  147. if provider_type == ToolProviderType.BUILT_IN:
  148. # check if the builtin tool need credentials
  149. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  150. builtin_tool = provider_controller.get_tool(tool_name)
  151. if not builtin_tool:
  152. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  153. if not provider_controller.need_credentials:
  154. return cast(
  155. BuiltinTool,
  156. builtin_tool.fork_tool_runtime(
  157. runtime=ToolRuntime(
  158. tenant_id=tenant_id,
  159. credentials={},
  160. invoke_from=invoke_from,
  161. tool_invoke_from=tool_invoke_from,
  162. )
  163. ),
  164. )
  165. builtin_provider = None
  166. if isinstance(provider_controller, PluginToolProviderController):
  167. provider_id_entity = ToolProviderID(provider_id)
  168. # get specific credentials
  169. if is_valid_uuid(credential_id):
  170. try:
  171. builtin_provider = (
  172. db.session.query(BuiltinToolProvider)
  173. .filter(
  174. BuiltinToolProvider.tenant_id == tenant_id,
  175. BuiltinToolProvider.id == credential_id,
  176. )
  177. .first()
  178. )
  179. except Exception as e:
  180. builtin_provider = None
  181. logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
  182. # if the provider has been deleted, raise an error
  183. if builtin_provider is None:
  184. raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
  185. # fallback to the default provider
  186. if builtin_provider is None:
  187. # use the default provider
  188. builtin_provider = (
  189. db.session.query(BuiltinToolProvider)
  190. .filter(
  191. BuiltinToolProvider.tenant_id == tenant_id,
  192. (BuiltinToolProvider.provider == str(provider_id_entity))
  193. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  194. )
  195. .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
  196. .first()
  197. )
  198. if builtin_provider is None:
  199. raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
  200. else:
  201. builtin_provider = (
  202. db.session.query(BuiltinToolProvider)
  203. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  204. .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
  205. .first()
  206. )
  207. if builtin_provider is None:
  208. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  209. encrypter, _ = create_provider_encrypter(
  210. tenant_id=tenant_id,
  211. config=[
  212. x.to_basic_provider_config()
  213. for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
  214. ],
  215. cache=ToolProviderCredentialsCache(
  216. tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
  217. ),
  218. )
  219. # decrypt the credentials
  220. decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
  221. # check if the credentials is expired
  222. if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
  223. # TODO: circular import
  224. from services.tools.builtin_tools_manage_service import BuiltinToolManageService
  225. # refresh the credentials
  226. tool_provider = ToolProviderID(provider_id)
  227. provider_name = tool_provider.provider_name
  228. redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
  229. system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
  230. oauth_handler = OAuthHandler()
  231. # refresh the credentials
  232. refreshed_credentials = oauth_handler.refresh_credentials(
  233. tenant_id=tenant_id,
  234. user_id=builtin_provider.user_id,
  235. plugin_id=tool_provider.plugin_id,
  236. provider=provider_name,
  237. redirect_uri=redirect_uri,
  238. system_credentials=system_credentials or {},
  239. credentials=decrypted_credentials,
  240. )
  241. # update the credentials
  242. builtin_provider.encrypted_credentials = (
  243. TypeAdapter(dict[str, Any])
  244. .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
  245. .decode("utf-8")
  246. )
  247. builtin_provider.expires_at = refreshed_credentials.expires_at
  248. db.session.commit()
  249. decrypted_credentials = refreshed_credentials.credentials
  250. return cast(
  251. BuiltinTool,
  252. builtin_tool.fork_tool_runtime(
  253. runtime=ToolRuntime(
  254. tenant_id=tenant_id,
  255. credentials=dict(decrypted_credentials),
  256. credential_type=CredentialType.of(builtin_provider.credential_type),
  257. runtime_parameters={},
  258. invoke_from=invoke_from,
  259. tool_invoke_from=tool_invoke_from,
  260. )
  261. ),
  262. )
  263. elif provider_type == ToolProviderType.API:
  264. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  265. encrypter, _ = create_tool_provider_encrypter(
  266. tenant_id=tenant_id,
  267. controller=api_provider,
  268. )
  269. return cast(
  270. ApiTool,
  271. api_provider.get_tool(tool_name).fork_tool_runtime(
  272. runtime=ToolRuntime(
  273. tenant_id=tenant_id,
  274. credentials=encrypter.decrypt(credentials),
  275. invoke_from=invoke_from,
  276. tool_invoke_from=tool_invoke_from,
  277. )
  278. ),
  279. )
  280. elif provider_type == ToolProviderType.WORKFLOW:
  281. workflow_provider = (
  282. db.session.query(WorkflowToolProvider)
  283. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  284. .first()
  285. )
  286. if workflow_provider is None:
  287. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  288. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  289. controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
  290. if controller_tools is None or len(controller_tools) == 0:
  291. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  292. return cast(
  293. WorkflowTool,
  294. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  295. runtime=ToolRuntime(
  296. tenant_id=tenant_id,
  297. credentials={},
  298. invoke_from=invoke_from,
  299. tool_invoke_from=tool_invoke_from,
  300. )
  301. ),
  302. )
  303. elif provider_type == ToolProviderType.APP:
  304. raise NotImplementedError("app provider not implemented")
  305. elif provider_type == ToolProviderType.PLUGIN:
  306. return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
  307. elif provider_type == ToolProviderType.MCP:
  308. return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
  309. else:
  310. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  311. @classmethod
  312. def get_agent_tool_runtime(
  313. cls,
  314. tenant_id: str,
  315. app_id: str,
  316. agent_tool: AgentToolEntity,
  317. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  318. variable_pool: Optional[VariablePool] = None,
  319. ) -> Tool:
  320. """
  321. get the agent tool runtime
  322. """
  323. tool_entity = cls.get_tool_runtime(
  324. provider_type=agent_tool.provider_type,
  325. provider_id=agent_tool.provider_id,
  326. tool_name=agent_tool.tool_name,
  327. tenant_id=tenant_id,
  328. invoke_from=invoke_from,
  329. tool_invoke_from=ToolInvokeFrom.AGENT,
  330. credential_id=agent_tool.credential_id,
  331. )
  332. runtime_parameters = {}
  333. parameters = tool_entity.get_merged_runtime_parameters()
  334. runtime_parameters = cls._convert_tool_parameters_type(
  335. parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
  336. )
  337. # decrypt runtime parameters
  338. encryption_manager = ToolParameterConfigurationManager(
  339. tenant_id=tenant_id,
  340. tool_runtime=tool_entity,
  341. provider_name=agent_tool.provider_id,
  342. provider_type=agent_tool.provider_type,
  343. identity_id=f"AGENT.{app_id}",
  344. )
  345. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  346. if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
  347. raise ValueError("runtime not found or runtime parameters not found")
  348. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  349. return tool_entity
  350. @classmethod
  351. def get_workflow_tool_runtime(
  352. cls,
  353. tenant_id: str,
  354. app_id: str,
  355. node_id: str,
  356. workflow_tool: "ToolEntity",
  357. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  358. variable_pool: Optional[VariablePool] = None,
  359. ) -> Tool:
  360. """
  361. get the workflow tool runtime
  362. """
  363. tool_runtime = cls.get_tool_runtime(
  364. provider_type=workflow_tool.provider_type,
  365. provider_id=workflow_tool.provider_id,
  366. tool_name=workflow_tool.tool_name,
  367. tenant_id=tenant_id,
  368. invoke_from=invoke_from,
  369. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  370. credential_id=workflow_tool.credential_id,
  371. )
  372. parameters = tool_runtime.get_merged_runtime_parameters()
  373. runtime_parameters = cls._convert_tool_parameters_type(
  374. parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
  375. )
  376. # decrypt runtime parameters
  377. encryption_manager = ToolParameterConfigurationManager(
  378. tenant_id=tenant_id,
  379. tool_runtime=tool_runtime,
  380. provider_name=workflow_tool.provider_id,
  381. provider_type=workflow_tool.provider_type,
  382. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  383. )
  384. if runtime_parameters:
  385. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  386. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  387. return tool_runtime
  388. @classmethod
  389. def get_tool_runtime_from_plugin(
  390. cls,
  391. tool_type: ToolProviderType,
  392. tenant_id: str,
  393. provider: str,
  394. tool_name: str,
  395. tool_parameters: dict[str, Any],
  396. credential_id: Optional[str] = None,
  397. ) -> Tool:
  398. """
  399. get tool runtime from plugin
  400. """
  401. tool_entity = cls.get_tool_runtime(
  402. provider_type=tool_type,
  403. provider_id=provider,
  404. tool_name=tool_name,
  405. tenant_id=tenant_id,
  406. invoke_from=InvokeFrom.SERVICE_API,
  407. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  408. credential_id=credential_id,
  409. )
  410. runtime_parameters = {}
  411. parameters = tool_entity.get_merged_runtime_parameters()
  412. for parameter in parameters:
  413. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  414. # save tool parameter to tool entity memory
  415. value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
  416. runtime_parameters[parameter.name] = value
  417. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  418. return tool_entity
  419. @classmethod
  420. def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
  421. """
  422. get the absolute path of the icon of the hardcoded provider
  423. :param provider: the name of the provider
  424. :return: the absolute path of the icon, the mime type of the icon
  425. """
  426. # get provider
  427. provider_controller = cls.get_hardcoded_provider(provider)
  428. absolute_path = path.join(
  429. path.dirname(path.realpath(__file__)),
  430. "builtin_tool",
  431. "providers",
  432. provider,
  433. "_assets",
  434. provider_controller.entity.identity.icon,
  435. )
  436. # check if the icon exists
  437. if not path.exists(absolute_path):
  438. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  439. # get the mime type
  440. mime_type, _ = mimetypes.guess_type(absolute_path)
  441. mime_type = mime_type or "application/octet-stream"
  442. return absolute_path, mime_type
  443. @classmethod
  444. def list_hardcoded_providers(cls):
  445. # use cache first
  446. if cls._builtin_providers_loaded:
  447. yield from list(cls._hardcoded_providers.values())
  448. return
  449. with cls._builtin_provider_lock:
  450. if cls._builtin_providers_loaded:
  451. yield from list(cls._hardcoded_providers.values())
  452. return
  453. yield from cls._list_hardcoded_providers()
  454. @classmethod
  455. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  456. """
  457. list all the plugin providers
  458. """
  459. manager = PluginToolManager()
  460. provider_entities = manager.fetch_tool_providers(tenant_id)
  461. return [
  462. PluginToolProviderController(
  463. entity=provider.declaration,
  464. plugin_id=provider.plugin_id,
  465. plugin_unique_identifier=provider.plugin_unique_identifier,
  466. tenant_id=tenant_id,
  467. )
  468. for provider in provider_entities
  469. ]
  470. @classmethod
  471. def list_builtin_providers(
  472. cls, tenant_id: str
  473. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  474. """
  475. list all the builtin providers
  476. """
  477. yield from cls.list_hardcoded_providers()
  478. # get plugin providers
  479. yield from cls.list_plugin_providers(tenant_id)
  480. @classmethod
  481. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  482. """
  483. list all the builtin providers
  484. """
  485. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  486. if provider_path.startswith("__"):
  487. continue
  488. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  489. if provider_path.startswith("__"):
  490. continue
  491. # init provider
  492. try:
  493. provider_class = load_single_subclass_from_source(
  494. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  495. script_path=path.join(
  496. path.dirname(path.realpath(__file__)),
  497. "builtin_tool",
  498. "providers",
  499. provider_path,
  500. f"{provider_path}.py",
  501. ),
  502. parent_type=BuiltinToolProviderController,
  503. )
  504. provider: BuiltinToolProviderController = provider_class()
  505. cls._hardcoded_providers[provider.entity.identity.name] = provider
  506. for tool in provider.get_tools():
  507. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  508. yield provider
  509. except Exception:
  510. logger.exception(f"load builtin provider {provider_path}")
  511. continue
  512. # set builtin providers loaded
  513. cls._builtin_providers_loaded = True
  514. @classmethod
  515. def load_hardcoded_providers_cache(cls):
  516. for _ in cls.list_hardcoded_providers():
  517. pass
  518. @classmethod
  519. def clear_hardcoded_providers_cache(cls):
  520. cls._hardcoded_providers = {}
  521. cls._builtin_providers_loaded = False
  522. @classmethod
  523. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  524. """
  525. get the tool label
  526. :param tool_name: the name of the tool
  527. :return: the label of the tool
  528. """
  529. if len(cls._builtin_tools_labels) == 0:
  530. # init the builtin providers
  531. cls.load_hardcoded_providers_cache()
  532. if tool_name not in cls._builtin_tools_labels:
  533. return None
  534. return cls._builtin_tools_labels[tool_name]
  535. @classmethod
  536. def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
  537. """
  538. list all the builtin providers
  539. """
  540. # according to multi credentials, select the one with is_default=True first, then created_at oldest
  541. # for compatibility with old version
  542. sql = """
  543. SELECT DISTINCT ON (tenant_id, provider) id
  544. FROM tool_builtin_providers
  545. WHERE tenant_id = :tenant_id
  546. ORDER BY tenant_id, provider, is_default DESC, created_at DESC
  547. """
  548. ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
  549. return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
  550. @classmethod
  551. def list_providers_from_api(
  552. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  553. ) -> list[ToolProviderApiEntity]:
  554. result_providers: dict[str, ToolProviderApiEntity] = {}
  555. filters = []
  556. if not typ:
  557. filters.extend(["builtin", "api", "workflow", "mcp"])
  558. else:
  559. filters.append(typ)
  560. with db.session.no_autoflush:
  561. if "builtin" in filters:
  562. builtin_providers = cls.list_builtin_providers(tenant_id)
  563. # key: provider name, value: provider
  564. db_builtin_providers = {
  565. str(ToolProviderID(provider.provider)): provider
  566. for provider in cls.list_default_builtin_providers(tenant_id)
  567. }
  568. # append builtin providers
  569. for provider in builtin_providers:
  570. # handle include, exclude
  571. if is_filtered(
  572. include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
  573. exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
  574. data=provider,
  575. name_func=lambda x: x.identity.name,
  576. ):
  577. continue
  578. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  579. provider_controller=provider,
  580. db_provider=db_builtin_providers.get(provider.entity.identity.name),
  581. decrypt_credentials=False,
  582. )
  583. if isinstance(provider, PluginToolProviderController):
  584. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  585. else:
  586. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  587. # get db api providers
  588. if "api" in filters:
  589. db_api_providers: list[ApiToolProvider] = (
  590. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  591. )
  592. api_provider_controllers: list[dict[str, Any]] = [
  593. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  594. for provider in db_api_providers
  595. ]
  596. # get labels
  597. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  598. for api_provider_controller in api_provider_controllers:
  599. user_provider = ToolTransformService.api_provider_to_user_provider(
  600. provider_controller=api_provider_controller["controller"],
  601. db_provider=api_provider_controller["provider"],
  602. decrypt_credentials=False,
  603. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  604. )
  605. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  606. if "workflow" in filters:
  607. # get workflow providers
  608. workflow_providers: list[WorkflowToolProvider] = (
  609. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  610. )
  611. workflow_provider_controllers: list[WorkflowToolProviderController] = []
  612. for workflow_provider in workflow_providers:
  613. try:
  614. workflow_provider_controllers.append(
  615. ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  616. )
  617. except Exception:
  618. # app has been deleted
  619. pass
  620. labels = ToolLabelManager.get_tools_labels(
  621. [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
  622. )
  623. for provider_controller in workflow_provider_controllers:
  624. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  625. provider_controller=provider_controller,
  626. labels=labels.get(provider_controller.provider_id, []),
  627. )
  628. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  629. if "mcp" in filters:
  630. mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
  631. for mcp_provider in mcp_providers:
  632. result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
  633. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  634. @classmethod
  635. def get_api_provider_controller(
  636. cls, tenant_id: str, provider_id: str
  637. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  638. """
  639. get the api provider
  640. :param tenant_id: the id of the tenant
  641. :param provider_id: the id of the provider
  642. :return: the provider controller, the credentials
  643. """
  644. provider: ApiToolProvider | None = (
  645. db.session.query(ApiToolProvider)
  646. .filter(
  647. ApiToolProvider.id == provider_id,
  648. ApiToolProvider.tenant_id == tenant_id,
  649. )
  650. .first()
  651. )
  652. if provider is None:
  653. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  654. auth_type = ApiProviderAuthType.NONE
  655. provider_auth_type = provider.credentials.get("auth_type")
  656. if provider_auth_type in ("api_key_header", "api_key"): # backward compatibility
  657. auth_type = ApiProviderAuthType.API_KEY_HEADER
  658. elif provider_auth_type == "api_key_query":
  659. auth_type = ApiProviderAuthType.API_KEY_QUERY
  660. controller = ApiToolProviderController.from_db(
  661. provider,
  662. auth_type,
  663. )
  664. controller.load_bundled_tools(provider.tools)
  665. return controller, provider.credentials
  666. @classmethod
  667. def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
  668. """
  669. get the api provider
  670. :param tenant_id: the id of the tenant
  671. :param provider_id: the id of the provider
  672. :return: the provider controller, the credentials
  673. """
  674. provider: MCPToolProvider | None = (
  675. db.session.query(MCPToolProvider)
  676. .filter(
  677. MCPToolProvider.server_identifier == provider_id,
  678. MCPToolProvider.tenant_id == tenant_id,
  679. )
  680. .first()
  681. )
  682. if provider is None:
  683. raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
  684. controller = MCPToolProviderController._from_db(provider)
  685. return controller
  686. @classmethod
  687. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  688. """
  689. get api provider
  690. """
  691. """
  692. get tool provider
  693. """
  694. provider_name = provider
  695. provider_obj: ApiToolProvider | None = (
  696. db.session.query(ApiToolProvider)
  697. .filter(
  698. ApiToolProvider.tenant_id == tenant_id,
  699. ApiToolProvider.name == provider,
  700. )
  701. .first()
  702. )
  703. if provider_obj is None:
  704. raise ValueError(f"you have not added provider {provider_name}")
  705. try:
  706. credentials = json.loads(provider_obj.credentials_str) or {}
  707. except Exception:
  708. credentials = {}
  709. # package tool provider controller
  710. auth_type = ApiProviderAuthType.NONE
  711. credentials_auth_type = credentials.get("auth_type")
  712. if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
  713. auth_type = ApiProviderAuthType.API_KEY_HEADER
  714. elif credentials_auth_type == "api_key_query":
  715. auth_type = ApiProviderAuthType.API_KEY_QUERY
  716. controller = ApiToolProviderController.from_db(
  717. provider_obj,
  718. auth_type,
  719. )
  720. # init tool configuration
  721. encrypter, _ = create_tool_provider_encrypter(
  722. tenant_id=tenant_id,
  723. controller=controller,
  724. )
  725. masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
  726. try:
  727. icon = json.loads(provider_obj.icon)
  728. except Exception:
  729. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  730. # add tool labels
  731. labels = ToolLabelManager.get_tool_labels(controller)
  732. return cast(
  733. dict,
  734. jsonable_encoder(
  735. {
  736. "schema_type": provider_obj.schema_type,
  737. "schema": provider_obj.schema,
  738. "tools": provider_obj.tools,
  739. "icon": icon,
  740. "description": provider_obj.description,
  741. "credentials": masked_credentials,
  742. "privacy_policy": provider_obj.privacy_policy,
  743. "custom_disclaimer": provider_obj.custom_disclaimer,
  744. "labels": labels,
  745. }
  746. ),
  747. )
  748. @classmethod
  749. def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
  750. return str(
  751. URL(dify_config.CONSOLE_API_URL or "/")
  752. / "console"
  753. / "api"
  754. / "workspaces"
  755. / "current"
  756. / "tool-provider"
  757. / "builtin"
  758. / provider_id
  759. / "icon"
  760. )
  761. @classmethod
  762. def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
  763. return str(
  764. URL(dify_config.CONSOLE_API_URL or "/")
  765. / "console"
  766. / "api"
  767. / "workspaces"
  768. / "current"
  769. / "plugin"
  770. / "icon"
  771. % {"tenant_id": tenant_id, "filename": filename}
  772. )
  773. @classmethod
  774. def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  775. try:
  776. workflow_provider: WorkflowToolProvider | None = (
  777. db.session.query(WorkflowToolProvider)
  778. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  779. .first()
  780. )
  781. if workflow_provider is None:
  782. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  783. icon: dict = json.loads(workflow_provider.icon)
  784. return icon
  785. except Exception:
  786. return {"background": "#252525", "content": "\ud83d\ude01"}
  787. @classmethod
  788. def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  789. try:
  790. api_provider: ApiToolProvider | None = (
  791. db.session.query(ApiToolProvider)
  792. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  793. .first()
  794. )
  795. if api_provider is None:
  796. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  797. icon: dict = json.loads(api_provider.icon)
  798. return icon
  799. except Exception:
  800. return {"background": "#252525", "content": "\ud83d\ude01"}
  801. @classmethod
  802. def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
  803. try:
  804. mcp_provider: MCPToolProvider | None = (
  805. db.session.query(MCPToolProvider)
  806. .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
  807. .first()
  808. )
  809. if mcp_provider is None:
  810. raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
  811. return mcp_provider.provider_icon
  812. except Exception:
  813. return {"background": "#252525", "content": "\ud83d\ude01"}
  814. @classmethod
  815. def get_tool_icon(
  816. cls,
  817. tenant_id: str,
  818. provider_type: ToolProviderType,
  819. provider_id: str,
  820. ) -> Union[str, dict]:
  821. """
  822. get the tool icon
  823. :param tenant_id: the id of the tenant
  824. :param provider_type: the type of the provider
  825. :param provider_id: the id of the provider
  826. :return:
  827. """
  828. provider_type = provider_type
  829. provider_id = provider_id
  830. if provider_type == ToolProviderType.BUILT_IN:
  831. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  832. if isinstance(provider, PluginToolProviderController):
  833. try:
  834. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  835. except Exception:
  836. return {"background": "#252525", "content": "\ud83d\ude01"}
  837. return cls.generate_builtin_tool_icon_url(provider_id)
  838. elif provider_type == ToolProviderType.API:
  839. return cls.generate_api_tool_icon_url(tenant_id, provider_id)
  840. elif provider_type == ToolProviderType.WORKFLOW:
  841. return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
  842. elif provider_type == ToolProviderType.PLUGIN:
  843. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  844. if isinstance(provider, PluginToolProviderController):
  845. try:
  846. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  847. except Exception:
  848. return {"background": "#252525", "content": "\ud83d\ude01"}
  849. raise ValueError(f"plugin provider {provider_id} not found")
  850. elif provider_type == ToolProviderType.MCP:
  851. return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
  852. else:
  853. raise ValueError(f"provider type {provider_type} not found")
  854. @classmethod
  855. def _convert_tool_parameters_type(
  856. cls,
  857. parameters: list[ToolParameter],
  858. variable_pool: Optional[VariablePool],
  859. tool_configurations: dict[str, Any],
  860. typ: Literal["agent", "workflow", "tool"] = "workflow",
  861. ) -> dict[str, Any]:
  862. """
  863. Convert tool parameters type
  864. """
  865. from core.workflow.nodes.tool.entities import ToolNodeData
  866. from core.workflow.nodes.tool.exc import ToolParameterError
  867. runtime_parameters = {}
  868. for parameter in parameters:
  869. if (
  870. parameter.type
  871. in {
  872. ToolParameter.ToolParameterType.SYSTEM_FILES,
  873. ToolParameter.ToolParameterType.FILE,
  874. ToolParameter.ToolParameterType.FILES,
  875. }
  876. and parameter.required
  877. and typ == "agent"
  878. ):
  879. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  880. # save tool parameter to tool entity memory
  881. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  882. if variable_pool:
  883. config = tool_configurations.get(parameter.name, {})
  884. if not (config and isinstance(config, dict) and config.get("value") is not None):
  885. continue
  886. tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
  887. if tool_input.type == "variable":
  888. variable = variable_pool.get(tool_input.value)
  889. if variable is None:
  890. raise ToolParameterError(f"Variable {tool_input.value} does not exist")
  891. parameter_value = variable.value
  892. elif tool_input.type in {"mixed", "constant"}:
  893. segment_group = variable_pool.convert_template(str(tool_input.value))
  894. parameter_value = segment_group.text
  895. else:
  896. raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
  897. runtime_parameters[parameter.name] = parameter_value
  898. else:
  899. value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
  900. runtime_parameters[parameter.name] = value
  901. return runtime_parameters
# Module-level side effect: this call runs once, when the module is first imported.
# Per the method name, it populates ToolManager's hardcoded providers cache (confirm in its definition).
ToolManager.load_hardcoded_providers_cache()