You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
  8. from yarl import URL
  9. import contexts
  10. from core.plugin.entities.plugin import ToolProviderID
  11. from core.plugin.impl.tool import PluginToolManager
  12. from core.tools.__base.tool_provider import ToolProviderController
  13. from core.tools.__base.tool_runtime import ToolRuntime
  14. from core.tools.mcp_tool.provider import MCPToolProviderController
  15. from core.tools.mcp_tool.tool import MCPTool
  16. from core.tools.plugin_tool.provider import PluginToolProviderController
  17. from core.tools.plugin_tool.tool import PluginTool
  18. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  19. from core.workflow.entities.variable_pool import VariablePool
  20. from services.tools.mcp_tools_mange_service import MCPToolManageService
  21. if TYPE_CHECKING:
  22. from core.workflow.nodes.tool.entities import ToolEntity
  23. from configs import dify_config
  24. from core.agent.entities import AgentToolEntity
  25. from core.app.entities.app_invoke_entities import InvokeFrom
  26. from core.helper.module_import_helper import load_single_subclass_from_source
  27. from core.helper.position_helper import is_filtered
  28. from core.model_runtime.utils.encoders import jsonable_encoder
  29. from core.tools.__base.tool import Tool
  30. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  31. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  32. from core.tools.builtin_tool.tool import BuiltinTool
  33. from core.tools.custom_tool.provider import ApiToolProviderController
  34. from core.tools.custom_tool.tool import ApiTool
  35. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  36. from core.tools.entities.common_entities import I18nObject
  37. from core.tools.entities.tool_entities import (
  38. ApiProviderAuthType,
  39. ToolInvokeFrom,
  40. ToolParameter,
  41. ToolProviderType,
  42. )
  43. from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
  44. from core.tools.tool_label_manager import ToolLabelManager
  45. from core.tools.utils.configuration import (
  46. ProviderConfigEncrypter,
  47. ToolParameterConfigurationManager,
  48. )
  49. from core.tools.workflow_as_tool.tool import WorkflowTool
  50. from extensions.ext_database import db
  51. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  52. from services.tools.tools_transform_service import ToolTransformService
  # Module-level logger, named after this module.
  logger = logging.getLogger(name=__name__)
  54. class ToolManager:
  # Cache of the in-tree builtin provider controllers, keyed by provider identity name.
  _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
  # Tool identity name -> display label, collected while the in-tree providers are loaded.
  _builtin_tools_labels: dict[str, Optional[I18nObject]] = {}
  # Becomes True once a complete scan of the in-tree providers has finished.
  _builtin_providers_loaded = False
  # Guards the cache-filling scan performed by list_hardcoded_providers.
  _builtin_provider_lock = Lock()
  @classmethod
  def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
      """
      Return the in-tree (hardcoded) builtin provider controller named ``provider``.

      The provider cache is filled lazily the first time it is found empty.

      :param provider: identity name of the provider
      :return: the cached provider controller
      :raises KeyError: if no in-tree provider has that name
      """
      if not cls._hardcoded_providers:
          # nothing cached yet: scan and load the in-tree providers
          cls.load_hardcoded_providers_cache()
      return cls._hardcoded_providers[provider]
  @classmethod
  def get_builtin_provider(
      cls, provider: str, tenant_id: str
  ) -> BuiltinToolProviderController | PluginToolProviderController:
      """
      Resolve a builtin provider by name, falling back to the tenant's plugin providers.

      In-tree providers take precedence; any other name is looked up as a plugin provider.

      :param provider: the provider name (in-tree identity name or plugin provider id)
      :param tenant_id: the tenant the plugin lookup is scoped to
      :return: the matching provider controller
      """
      if not cls._hardcoded_providers:
          cls.load_hardcoded_providers_cache()

      if provider in cls._hardcoded_providers:
          return cls._hardcoded_providers[provider]

      plugin_provider = cls.get_plugin_provider(provider, tenant_id)
      if plugin_provider:
          return plugin_provider
      # unreachable in practice (get_plugin_provider raises instead of returning a falsy
      # value); kept so the fallback still ends in a KeyError for an unknown name
      return cls._hardcoded_providers[provider]
  @classmethod
  def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
      """
      Return the plugin tool provider controller for ``provider``.

      Controllers are memoized in the ``contexts.plugin_tool_providers`` context variable;
      that variable and its companion lock are initialized here when not yet set.

      :param provider: the plugin provider id
      :param tenant_id: the tenant to fetch the provider for
      :return: the plugin provider controller
      :raises ToolProviderNotFoundError: if the plugin manager returns no provider
      """
      # lazily initialise the per-context cache and its lock
      try:
          contexts.plugin_tool_providers.get()
      except LookupError:
          contexts.plugin_tool_providers.set({})
          contexts.plugin_tool_providers_lock.set(Lock())

      with contexts.plugin_tool_providers_lock.get():
          cache = contexts.plugin_tool_providers.get()
          if provider in cache:
              return cache[provider]

          entity = PluginToolManager().fetch_tool_provider(tenant_id, provider)
          if not entity:
              raise ToolProviderNotFoundError(f"plugin provider {provider} not found")

          controller = PluginToolProviderController(
              entity=entity.declaration,
              plugin_id=entity.plugin_id,
              plugin_unique_identifier=entity.plugin_unique_identifier,
              tenant_id=tenant_id,
          )
          cache[provider] = controller
      return controller
  @classmethod
  def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
      """
      Look up a single tool on a builtin (in-tree or plugin) provider.

      :param provider: the provider name
      :param tool_name: the tool name within that provider
      :param tenant_id: the tenant id, used for the plugin provider fallback
      :return: the tool (``None`` is never returned despite the annotation)
      :raises ToolNotFoundError: if the provider has no tool with that name
      """
      found = cls.get_builtin_provider(provider, tenant_id).get_tool(tool_name)
      if found is not None:
          return found
      raise ToolNotFoundError(f"tool {tool_name} not found")
  @classmethod
  def get_tool_runtime(
      cls,
      provider_type: ToolProviderType,
      provider_id: str,
      tool_name: str,
      tenant_id: str,
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
      """
      Build a runnable tool instance for the given provider/tool pair.

      For builtin, API and workflow providers the tool is forked with a fresh ``ToolRuntime``
      carrying the tenant, decrypted credentials (empty when none are needed) and the
      invocation sources. Plugin and MCP tools are returned as produced by their provider
      controller's ``get_tool``.

      :param provider_type: the type of the provider
      :param provider_id: the id of the provider (the provider name for builtin providers)
      :param tool_name: the name of the tool
      :param tenant_id: the tenant id
      :param invoke_from: where the invocation originates
      :param tool_invoke_from: which subsystem invokes the tool
      :return: the tool with its runtime attached
      :raises ToolProviderNotFoundError: if the builtin tool, the stored builtin provider,
          the workflow provider or its tool cannot be found, or the provider type is unknown
      :raises NotImplementedError: for ``ToolProviderType.APP``
      """
      if provider_type == ToolProviderType.BUILT_IN:
          provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
          builtin_tool = provider_controller.get_tool(tool_name)
          if not builtin_tool:
              raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

          if not provider_controller.need_credentials:
              # no credentials required: fork with an empty credential set
              return cast(
                  BuiltinTool,
                  builtin_tool.fork_tool_runtime(
                      runtime=ToolRuntime(
                          tenant_id=tenant_id,
                          credentials={},
                          invoke_from=invoke_from,
                          tool_invoke_from=tool_invoke_from,
                      )
                  ),
              )

          if isinstance(provider_controller, PluginToolProviderController):
              provider_id_entity = ToolProviderID(provider_id)
              # the stored row may use either the full plugin provider id or the bare provider name
              builtin_provider: BuiltinToolProvider | None = (
                  db.session.query(BuiltinToolProvider)
                  .filter(
                      BuiltinToolProvider.tenant_id == tenant_id,
                      (BuiltinToolProvider.provider == str(provider_id_entity))
                      | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                  )
                  .first()
              )
          else:
              builtin_provider = (
                  db.session.query(BuiltinToolProvider)
                  .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                  .first()
              )
          if builtin_provider is None:
              raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

          # decrypt the stored credentials using the provider's credential schema
          credentials = builtin_provider.credentials
          tool_configuration = ProviderConfigEncrypter(
              tenant_id=tenant_id,
              config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
              provider_type=provider_controller.provider_type.value,
              provider_identity=provider_controller.entity.identity.name,
          )
          decrypted_credentials = tool_configuration.decrypt(credentials)
          return cast(
              BuiltinTool,
              builtin_tool.fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials=decrypted_credentials,
                      runtime_parameters={},
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )
      elif provider_type == ToolProviderType.API:
          api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
          # decrypt the credentials
          tool_configuration = ProviderConfigEncrypter(
              tenant_id=tenant_id,
              config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
              provider_type=api_provider.provider_type.value,
              provider_identity=api_provider.entity.identity.name,
          )
          decrypted_credentials = tool_configuration.decrypt(credentials)
          return cast(
              ApiTool,
              api_provider.get_tool(tool_name).fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials=decrypted_credentials,
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )
      elif provider_type == ToolProviderType.WORKFLOW:
          workflow_provider = (
              db.session.query(WorkflowToolProvider)
              .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
              .first()
          )
          if workflow_provider is None:
              raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

          controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
          controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
          if not controller_tools:
              raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

          # reuse the list fetched above rather than calling get_tools() a second time
          return cast(
              WorkflowTool,
              controller_tools[0].fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials={},
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )
      elif provider_type == ToolProviderType.APP:
          raise NotImplementedError("app provider not implemented")
      elif provider_type == ToolProviderType.PLUGIN:
          return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
      elif provider_type == ToolProviderType.MCP:
          return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
      else:
          raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  @classmethod
  def get_agent_tool_runtime(
      cls,
      tenant_id: str,
      app_id: str,
      agent_tool: AgentToolEntity,
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      variable_pool: Optional[VariablePool] = None,
  ) -> Tool:
      """
      Build the runtime of a tool configured on an agent app.

      :param tenant_id: the tenant id
      :param app_id: the agent app id; scopes parameter decryption as ``AGENT.<app_id>``
      :param agent_tool: the agent's tool configuration
      :param invoke_from: where the invocation originates
      :param variable_pool: forwarded to ``_convert_tool_parameters_type``
      :return: the tool with its runtime parameters populated
      :raises ValueError: if the tool has no runtime or no runtime parameters
      """
      tool_entity = cls.get_tool_runtime(
          provider_type=agent_tool.provider_type,
          provider_id=agent_tool.provider_id,
          tool_name=agent_tool.tool_name,
          tenant_id=tenant_id,
          invoke_from=invoke_from,
          tool_invoke_from=ToolInvokeFrom.AGENT,
      )
      parameters = tool_entity.get_merged_runtime_parameters()
      runtime_parameters = cls._convert_tool_parameters_type(
          parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
      )
      # decrypt runtime parameters
      encryption_manager = ToolParameterConfigurationManager(
          tenant_id=tenant_id,
          tool_runtime=tool_entity,
          provider_name=agent_tool.provider_id,
          provider_type=agent_tool.provider_type,
          identity_id=f"AGENT.{app_id}",
      )
      runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
      if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")
      tool_entity.runtime.runtime_parameters.update(runtime_parameters)
      return tool_entity
  @classmethod
  def get_workflow_tool_runtime(
      cls,
      tenant_id: str,
      app_id: str,
      node_id: str,
      workflow_tool: "ToolEntity",
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      variable_pool: Optional[VariablePool] = None,
  ) -> Tool:
      """
      Build the runtime of a tool used by a workflow tool node.

      :param tenant_id: the tenant id
      :param app_id: the workflow app id
      :param node_id: the workflow node id; with ``app_id`` it scopes parameter decryption
          as ``WORKFLOW.<app_id>.<node_id>``
      :param workflow_tool: the node's tool configuration
      :param invoke_from: where the invocation originates
      :param variable_pool: forwarded to ``_convert_tool_parameters_type``
      :return: the tool with its runtime parameters populated
      :raises ValueError: if the tool has no runtime or no runtime parameters
      """
      tool_runtime = cls.get_tool_runtime(
          provider_type=workflow_tool.provider_type,
          provider_id=workflow_tool.provider_id,
          tool_name=workflow_tool.tool_name,
          tenant_id=tenant_id,
          invoke_from=invoke_from,
          tool_invoke_from=ToolInvokeFrom.WORKFLOW,
      )
      parameters = tool_runtime.get_merged_runtime_parameters()
      runtime_parameters = cls._convert_tool_parameters_type(
          parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
      )
      # decrypt runtime parameters
      encryption_manager = ToolParameterConfigurationManager(
          tenant_id=tenant_id,
          tool_runtime=tool_runtime,
          provider_name=workflow_tool.provider_id,
          provider_type=workflow_tool.provider_type,
          identity_id=f"WORKFLOW.{app_id}.{node_id}",
      )
      if runtime_parameters:
          runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
      # same guard as get_agent_tool_runtime: fail with a clear error instead of an AttributeError
      if tool_runtime.runtime is None or tool_runtime.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")
      tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
      return tool_runtime
  @classmethod
  def get_tool_runtime_from_plugin(
      cls,
      tool_type: ToolProviderType,
      tenant_id: str,
      provider: str,
      tool_name: str,
      tool_parameters: dict[str, Any],
  ) -> Tool:
      """
      Build the runtime of a tool invoked from a plugin.

      Only parameters whose form is ``FORM`` are taken from ``tool_parameters``; each value is
      passed through the parameter's ``init_frontend_parameter`` before being stored.

      :param tool_type: the provider type of the tool
      :param tenant_id: the tenant id
      :param provider: the provider id
      :param tool_name: the tool name
      :param tool_parameters: raw parameter values keyed by parameter name
      :return: the tool with its runtime parameters populated
      :raises ValueError: if the tool has no runtime or no runtime parameters
      """
      tool_entity = cls.get_tool_runtime(
          provider_type=tool_type,
          provider_id=provider,
          tool_name=tool_name,
          tenant_id=tenant_id,
          invoke_from=InvokeFrom.SERVICE_API,
          tool_invoke_from=ToolInvokeFrom.PLUGIN,
      )
      runtime_parameters = {
          parameter.name: parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
          for parameter in tool_entity.get_merged_runtime_parameters()
          if parameter.form == ToolParameter.ToolParameterForm.FORM
      }
      # same guard as get_agent_tool_runtime: fail with a clear error instead of an AttributeError
      if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")
      tool_entity.runtime.runtime_parameters.update(runtime_parameters)
      return tool_entity
  @classmethod
  def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
      """
      Locate the icon file of an in-tree builtin provider.

      The icon lives at ``builtin_tool/providers/<provider>/_assets/<icon>`` next to this module,
      where ``<icon>`` comes from the provider's identity.

      :param provider: the identity name of the provider
      :return: ``(absolute path of the icon, MIME type)``; the MIME type falls back to
          ``application/octet-stream`` when it cannot be guessed from the file name
      :raises ToolProviderNotFoundError: if the icon file does not exist
      """
      icon_name = cls.get_hardcoded_provider(provider).entity.identity.icon
      assets_dir = path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider, "_assets")
      icon_path = path.join(assets_dir, icon_name)

      if not path.exists(icon_path):
          raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")

      guessed_type = mimetypes.guess_type(icon_path)[0]
      return icon_path, guessed_type or "application/octet-stream"
  @classmethod
  def list_hardcoded_providers(cls):
      """
      Yield every in-tree builtin provider controller, loading them on first use.

      The first caller drains ``_list_hardcoded_providers`` while holding
      ``_builtin_provider_lock``; concurrent callers wait for that scan rather than repeating
      it. The lock is released before anything is yielded. ``Lock`` is not reentrant, so
      holding it across ``yield`` would let a consumer that stops early keep it held, and a
      consumer that calls back into a cache loader while iterating would deadlock on it.

      :return: a generator over a snapshot of the cached provider controllers
      """
      if not cls._builtin_providers_loaded:
          with cls._builtin_provider_lock:
              # re-check: another thread may have finished the scan while we waited
              if not cls._builtin_providers_loaded:
                  # drain fully so the cache and the loaded flag are both filled
                  for _ in cls._list_hardcoded_providers():
                      pass
      yield from list(cls._hardcoded_providers.values())
  @classmethod
  def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
      """
      Build a controller for every plugin tool provider available to the tenant.

      :param tenant_id: the tenant id
      :return: one controller per provider returned by the plugin manager, in that order
      """
      tool_manager = PluginToolManager()
      controllers: list[PluginToolProviderController] = []
      for entity in tool_manager.fetch_tool_providers(tenant_id):
          controllers.append(
              PluginToolProviderController(
                  entity=entity.declaration,
                  plugin_id=entity.plugin_id,
                  plugin_unique_identifier=entity.plugin_unique_identifier,
                  tenant_id=tenant_id,
              )
          )
      return controllers
  @classmethod
  def list_builtin_providers(
      cls, tenant_id: str
  ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
      """
      Yield all builtin providers: the in-tree ones first, then the tenant's plugin providers.

      :param tenant_id: the tenant whose plugin providers are included
      """
      yield from cls.list_hardcoded_providers()
      plugin_controllers = cls.list_plugin_providers(tenant_id)
      yield from plugin_controllers
  @classmethod
  def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
      """
      Scan ``builtin_tool/providers`` and yield a controller for each provider package.

      Every sub-directory ``<name>`` (names starting with ``__`` are skipped) is loaded from
      ``<name>/<name>.py`` as a ``BuiltinToolProviderController`` subclass. Each loaded
      controller is cached in ``_hardcoded_providers`` and its tools' labels in
      ``_builtin_tools_labels``. A provider that fails to load is logged and skipped.
      ``_builtin_providers_loaded`` is set only after the generator has been fully consumed.
      """
      providers_dir = path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")
      for provider_path in listdir(providers_dir):
          if provider_path.startswith("__"):
              continue
          provider_dir = path.join(providers_dir, provider_path)
          if not path.isdir(provider_dir):
              continue
          try:
              provider_class = load_single_subclass_from_source(
                  module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
                  script_path=path.join(provider_dir, f"{provider_path}.py"),
                  parent_type=BuiltinToolProviderController,
              )
              provider: BuiltinToolProviderController = provider_class()
              cls._hardcoded_providers[provider.entity.identity.name] = provider
              for tool in provider.get_tools():
                  cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
          except Exception:
              logger.exception("load builtin provider %s", provider_path)
              continue
          # yield outside the try so errors raised at the consumer's side are not
          # mistaken for (and logged as) provider load failures
          yield provider
      # set builtin providers loaded
      cls._builtin_providers_loaded = True
  @classmethod
  def load_hardcoded_providers_cache(cls):
      """Fill the in-tree provider cache by draining ``list_hardcoded_providers``; returns ``None``."""
      providers = cls.list_hardcoded_providers()
      for _provider in providers:
          continue
  @classmethod
  def clear_hardcoded_providers_cache(cls):
      """
      Drop every cached in-tree provider so the next access rescans the providers directory.

      The tool-label cache is reset too. It is filled by the same scan, and ``get_tool_label``
      only triggers a rescan when that cache is empty, so leaving it populated would keep
      serving labels of providers that may no longer exist.
      """
      cls._hardcoded_providers = {}
      cls._builtin_tools_labels = {}
      cls._builtin_providers_loaded = False
  @classmethod
  def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
      """
      Return the display label of an in-tree builtin tool.

      :param tool_name: the tool's identity name
      :return: the label, or ``None`` if no in-tree tool has that name
      """
      if not cls._builtin_tools_labels:
          # labels are collected while the in-tree providers are loaded
          cls.load_hardcoded_providers_cache()
      return cls._builtin_tools_labels.get(tool_name)
  @classmethod
  def list_providers_from_api(
      cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  ) -> list[ToolProviderApiEntity]:
      """
      List the tool providers visible to a tenant, converted to API entities.

      :param user_id: the requesting user (not used by this method)
      :param tenant_id: the tenant whose providers are listed
      :param typ: restrict the listing to one category (``"builtin"``, ``"api"``,
          ``"workflow"`` or ``"mcp"``); a falsy value lists all four
      :return: the providers, ordered by ``BuiltinToolProviderSort.sort``
      """
      result_providers: dict[str, ToolProviderApiEntity] = {}
      filters = [typ] if typ else ["builtin", "api", "workflow", "mcp"]

      with db.session.no_autoflush:
          if "builtin" in filters:
              # get builtin providers (in-tree and plugin)
              builtin_providers = cls.list_builtin_providers(tenant_id)
              # get db builtin providers
              db_builtin_providers: list[BuiltinToolProvider] = (
                  db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
              )
              # Rewrite the stored provider names to full ToolProviderID strings and index them.
              # NOTE(review): this mutates ORM objects; no_autoflush only defers flushing, so a
              # later commit on this session could persist the rewritten names — confirm intended.
              db_providers_by_name: dict[str, BuiltinToolProvider] = {}
              for db_provider in db_builtin_providers:
                  db_provider.provider = str(ToolProviderID(db_provider.provider))
                  # keep the first row per name, matching the previous first-match search
                  db_providers_by_name.setdefault(db_provider.provider, db_provider)

              for provider in builtin_providers:
                  # honour the POSITION_TOOL_INCLUDES / POSITION_TOOL_EXCLUDES settings;
                  # controllers expose their identity through ``entity``
                  if is_filtered(
                      include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
                      exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
                      data=provider,
                      name_func=lambda x: x.entity.identity.name,
                  ):
                      continue

                  user_provider = ToolTransformService.builtin_provider_to_user_provider(
                      provider_controller=provider,
                      db_provider=db_providers_by_name.get(provider.entity.identity.name),
                      decrypt_credentials=False,
                  )
                  key_prefix = (
                      "plugin_provider" if isinstance(provider, PluginToolProviderController) else "builtin_provider"
                  )
                  result_providers[f"{key_prefix}.{user_provider.name}"] = user_provider

          # get db api providers
          if "api" in filters:
              db_api_providers: list[ApiToolProvider] = (
                  db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
              )
              api_provider_controllers: list[dict[str, Any]] = [
                  {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
                  for provider in db_api_providers
              ]
              # get labels
              labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
              for api_provider_controller in api_provider_controllers:
                  user_provider = ToolTransformService.api_provider_to_user_provider(
                      provider_controller=api_provider_controller["controller"],
                      db_provider=api_provider_controller["provider"],
                      decrypt_credentials=False,
                      labels=labels.get(api_provider_controller["controller"].provider_id, []),
                  )
                  result_providers[f"api_provider.{user_provider.name}"] = user_provider

          if "workflow" in filters:
              # get workflow providers
              workflow_providers: list[WorkflowToolProvider] = (
                  db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
              )
              workflow_provider_controllers: list[WorkflowToolProviderController] = []
              for workflow_provider in workflow_providers:
                  try:
                      workflow_provider_controllers.append(
                          ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
                      )
                  except Exception:
                      # best-effort: the backing app may have been deleted; skip but leave a trace
                      logger.debug("skip workflow tool provider %s", workflow_provider.id, exc_info=True)

              labels = ToolLabelManager.get_tools_labels(
                  [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
              )
              for provider_controller in workflow_provider_controllers:
                  user_provider = ToolTransformService.workflow_provider_to_user_provider(
                      provider_controller=provider_controller,
                      labels=labels.get(provider_controller.provider_id, []),
                  )
                  result_providers[f"workflow_provider.{user_provider.name}"] = user_provider

          if "mcp" in filters:
              mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
              for mcp_provider in mcp_providers:
                  result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider

      return BuiltinToolProviderSort.sort(list(result_providers.values()))
  @classmethod
  def get_api_provider_controller(
      cls, tenant_id: str, provider_id: str
  ) -> tuple[ApiToolProviderController, dict[str, Any]]:
      """
      Load a tenant's API tool provider and build its controller.

      :param tenant_id: the id of the tenant
      :param provider_id: the id of the API provider
      :return: the controller (with its bundled tools loaded) and the credentials as stored
          on the provider; callers such as ``get_tool_runtime`` decrypt them
      :raises ToolProviderNotFoundError: if the tenant has no API provider with that id
      """
      provider: ApiToolProvider | None = (
          db.session.query(ApiToolProvider)
          .filter(
              ApiToolProvider.id == provider_id,
              ApiToolProvider.tenant_id == tenant_id,
          )
          .first()
      )
      if provider is None:
          raise ToolProviderNotFoundError(f"api provider {provider_id} not found")

      # read the credentials once; a missing "auth_type" means no authentication
      credentials = provider.credentials
      auth_type = (
          ApiProviderAuthType.API_KEY if credentials.get("auth_type") == "api_key" else ApiProviderAuthType.NONE
      )
      controller = ApiToolProviderController.from_db(provider, auth_type)
      controller.load_bundled_tools(provider.tools)
      return controller, credentials
  @classmethod
  def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
      """
      Load a tenant's MCP tool provider and build its controller.

      :param tenant_id: the id of the tenant
      :param provider_id: the MCP provider's server identifier
      :return: the provider controller
      :raises ToolProviderNotFoundError: if the tenant has no MCP provider with that identifier
      """
      provider: MCPToolProvider | None = (
          db.session.query(MCPToolProvider)
          .filter(
              MCPToolProvider.server_identifier == provider_id,
              MCPToolProvider.tenant_id == tenant_id,
          )
          .first()
      )
      if provider is None:
          raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
      return MCPToolProviderController._from_db(provider)
  @classmethod
  def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
      """
      Return the stored configuration of a tenant's API tool provider, for display.

      Credentials are decrypted and then masked before being returned. Unreadable stored
      credentials are treated as empty, and an unreadable icon is replaced by a default one.

      :param provider: the provider name
      :param tenant_id: the tenant id
      :return: a JSON-serializable dict with the schema type, schema, tools, icon,
          description, masked credentials, privacy policy, custom disclaimer and labels
      :raises ValueError: if the tenant has no API provider with that name
      """
      provider_obj: ApiToolProvider | None = (
          db.session.query(ApiToolProvider)
          .filter(
              ApiToolProvider.tenant_id == tenant_id,
              ApiToolProvider.name == provider,
          )
          .first()
      )
      if provider_obj is None:
          raise ValueError(f"you have not added provider {provider}")

      # best-effort decode; anything that is not a JSON object counts as "no credentials"
      try:
          credentials = json.loads(provider_obj.credentials_str)
      except Exception:
          credentials = {}
      if not isinstance(credentials, dict):
          credentials = {}

      # .get(): the fallback above can leave credentials empty, and a missing
      # auth_type must mean "no auth" instead of raising KeyError
      auth_type = (
          ApiProviderAuthType.API_KEY if credentials.get("auth_type") == "api_key" else ApiProviderAuthType.NONE
      )
      # package tool provider controller
      controller = ApiToolProviderController.from_db(provider_obj, auth_type)

      # init tool configuration
      tool_configuration = ProviderConfigEncrypter(
          tenant_id=tenant_id,
          config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
          provider_type=controller.provider_type.value,
          provider_identity=controller.entity.identity.name,
      )
      decrypted_credentials = tool_configuration.decrypt(credentials)
      masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

      try:
          icon = json.loads(provider_obj.icon)
      except Exception:
          # "\U0001f601" is the intended emoji; the JS-style "\ud83d\ude01" spelling yields two
          # lone surrogates in Python, which cannot be UTF-8 encoded
          icon = {"background": "#252525", "content": "\U0001f601"}

      # add tool labels
      labels = ToolLabelManager.get_tool_labels(controller)

      return cast(
          dict,
          jsonable_encoder(
              {
                  "schema_type": provider_obj.schema_type,
                  "schema": provider_obj.schema,
                  "tools": provider_obj.tools,
                  "icon": icon,
                  "description": provider_obj.description,
                  "credentials": masked_credentials,
                  "privacy_policy": provider_obj.privacy_policy,
                  "custom_disclaimer": provider_obj.custom_disclaimer,
                  "labels": labels,
              }
          ),
      )
  @classmethod
  def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
      """
      Build the console API URL that serves a builtin provider's icon.

      :param provider_id: the builtin provider id, appended as a path segment
      :return: ``<CONSOLE_API_URL>/console/api/workspaces/current/tool-provider/builtin/<provider_id>/icon``,
          rooted at ``/`` when CONSOLE_API_URL is empty or unset
      """
      url = URL(dify_config.CONSOLE_API_URL or "/")
      segments = ("console", "api", "workspaces", "current", "tool-provider", "builtin", provider_id, "icon")
      for segment in segments:
          url = url / segment
      return str(url)
  @classmethod
  def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
      """
      Build the console API URL that serves a plugin's icon file.

      :param tenant_id: the tenant id, sent as the ``tenant_id`` query parameter
      :param filename: the icon file name, sent as the ``filename`` query parameter
      :return: ``<CONSOLE_API_URL>/console/api/workspaces/current/plugin/icon?tenant_id=...&filename=...``
      """
      base = URL(dify_config.CONSOLE_API_URL or "/")
      icon_endpoint = base / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
      query = {"tenant_id": tenant_id, "filename": filename}
      return str(icon_endpoint % query)
  @classmethod
  def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
      """
      Return the icon stored on a tenant's workflow tool provider.

      Best-effort: if the provider does not exist or anything fails while loading or decoding
      its icon, a default icon (a grinning-face emoji on a dark background) is returned.

      :param tenant_id: the tenant id
      :param provider_id: the workflow tool provider id
      :return: the decoded icon dict, or the default icon
      """
      # "\U0001f601" is the intended emoji; the JS-style "\ud83d\ude01" spelling yields two
      # lone surrogates in Python, which cannot be UTF-8 encoded
      default_icon = {"background": "#252525", "content": "\U0001f601"}
      try:
          workflow_provider: WorkflowToolProvider | None = (
              db.session.query(WorkflowToolProvider)
              .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
              .first()
          )
          if workflow_provider is None:
              return default_icon
          icon: dict = json.loads(workflow_provider.icon)
          return icon
      except Exception:
          return default_icon
  @classmethod
  def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
      """
      Return the stored icon of an API tool provider.

      This is best effort. If no provider with this id exists for the tenant, or if
      anything fails while querying it or parsing its icon JSON, a default emoji
      icon is returned instead of raising.

      :param tenant_id: the id of the tenant owning the provider
      :param provider_id: the id of the API tool provider
      :return: the icon parsed from ``ApiToolProvider.icon``, or the default icon
      """
      # Use the real code point U+1F601. A surrogate-pair escape such as
      # "\ud83d\ude01" yields two lone surrogates in Python, and UTF-8 encoding
      # rejects those. JSON output with ensure_ascii=True is identical either way.
      default_icon = {"background": "#252525", "content": "\U0001f601"}
      try:
          api_provider: ApiToolProvider | None = (
              db.session.query(ApiToolProvider)
              .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
              .first()
          )
          if api_provider is None:
              return default_icon
          icon: dict = json.loads(api_provider.icon)
          return icon
      except Exception:
          # Deliberate best-effort fallback: an icon must never break the caller.
          return default_icon
  @classmethod
  def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
      """
      Return the icon of an MCP tool provider.

      The provider is looked up by ``MCPToolProvider.server_identifier``, not by its
      row id. This is best effort: a missing provider, or any error during the
      lookup, yields a default emoji icon instead of raising.

      :param tenant_id: the id of the tenant owning the provider
      :param provider_id: the MCP server identifier
      :return: ``MCPToolProvider.provider_icon``, or the default icon dict
      """
      # Use the real code point U+1F601. A surrogate-pair escape such as
      # "\ud83d\ude01" yields two lone surrogates in Python, and UTF-8 encoding
      # rejects those. JSON output with ensure_ascii=True is identical either way.
      default_icon = {"background": "#252525", "content": "\U0001f601"}
      try:
          mcp_provider: MCPToolProvider | None = (
              db.session.query(MCPToolProvider)
              .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
              .first()
          )
          if mcp_provider is None:
              return default_icon
          return mcp_provider.provider_icon
      except Exception:
          # Deliberate best-effort fallback: an icon must never break the caller.
          return default_icon
  @classmethod
  def get_tool_icon(
      cls,
      tenant_id: str,
      provider_type: ToolProviderType,
      provider_id: str,
  ) -> Union[str, dict]:
      """
      get the tool icon

      Built-in providers get a console icon URL. Plugin-backed providers get a
      plugin icon URL, or a default emoji icon if building that URL fails. API,
      workflow and MCP providers delegate to their ``generate_*_tool_icon_url``
      helpers.

      :param tenant_id: the id of the tenant
      :param provider_type: the type of the provider
      :param provider_id: the id of the provider
      :return: an icon URL string or an icon dict, depending on the provider type
      :raises ValueError: a PLUGIN provider does not resolve to a plugin controller,
          or the provider type is not handled
      """
      if provider_type in (ToolProviderType.BUILT_IN, ToolProviderType.PLUGIN):
          # Both built-in and plugin providers are resolved through get_builtin_provider.
          provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
          if isinstance(provider, PluginToolProviderController):
              try:
                  return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
              except Exception:
                  # Real code point U+1F601; "\ud83d\ude01" would be two lone surrogates.
                  return {"background": "#252525", "content": "\U0001f601"}
          if provider_type == ToolProviderType.PLUGIN:
              raise ValueError(f"plugin provider {provider_id} not found")
          return cls.generate_builtin_tool_icon_url(provider_id)
      if provider_type == ToolProviderType.API:
          return cls.generate_api_tool_icon_url(tenant_id, provider_id)
      if provider_type == ToolProviderType.WORKFLOW:
          return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
      if provider_type == ToolProviderType.MCP:
          return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
      raise ValueError(f"provider type {provider_type} not found")
  @classmethod
  def _convert_tool_parameters_type(
      cls,
      parameters: list[ToolParameter],
      variable_pool: Optional[VariablePool],
      tool_configurations: dict[str, Any],
      typ: Literal["agent", "workflow", "tool"] = "workflow",
  ) -> dict[str, Any]:
      """
      Convert tool parameters type

      Build the runtime values of a tool's FORM parameters. Parameters whose
      ``form`` is not ``ToolParameterForm.FORM`` are not added to the result.

      :param parameters: the tool's declared parameters
      :param variable_pool: when truthy, each FORM parameter's configuration is
          parsed as a ``ToolNodeData.ToolInput`` and resolved against this pool.
          Otherwise the value comes from ``parameter.init_frontend_parameter``.
      :param tool_configurations: mapping of parameter name to its configured value/input
      :param typ: the calling context; ``"agent"`` rejects required file-type parameters
      :return: mapping of parameter name to resolved runtime value
      :raises ValueError: a required file-type parameter is used with ``typ == "agent"``
      :raises ToolParameterError: a referenced variable is missing from the pool,
          or the tool input type is unknown
      """
      from core.workflow.nodes.tool.entities import ToolNodeData
      from core.workflow.nodes.tool.exc import ToolParameterError

      # Loop-invariant: the parameter types that carry files.
      file_parameter_types = {
          ToolParameter.ToolParameterType.SYSTEM_FILES,
          ToolParameter.ToolParameterType.FILE,
          ToolParameter.ToolParameterType.FILES,
      }

      runtime_parameters: dict[str, Any] = {}
      for parameter in parameters:
          if parameter.type in file_parameter_types and parameter.required and typ == "agent":
              raise ValueError(f"file type parameter {parameter.name} not supported in agent")

          # save tool parameter to tool entity memory (FORM parameters only)
          if parameter.form != ToolParameter.ToolParameterForm.FORM:
              continue

          if not variable_pool:
              # No pool to resolve against: use the parameter's frontend initialisation.
              value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
              runtime_parameters[parameter.name] = value
              continue

          config = tool_configurations.get(parameter.name, {})
          # Skip parameters with no configured dict or a None "value".
          if not (config and isinstance(config, dict) and config.get("value") is not None):
              continue
          tool_input = ToolNodeData.ToolInput(**config)
          if tool_input.type == "variable":
              variable = variable_pool.get(tool_input.value)
              if variable is None:
                  raise ToolParameterError(f"Variable {tool_input.value} does not exist")
              parameter_value = variable.value
          elif tool_input.type in {"mixed", "constant"}:
              # Both kinds go through template rendering against the pool.
              segment_group = variable_pool.convert_template(str(tool_input.value))
              parameter_value = segment_group.text
          else:
              raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
          runtime_parameters[parameter.name] = parameter_value
      return runtime_parameters
  # Appears to be a module-level statement (it names ToolManager explicitly rather
  # than using cls), so it runs once when this module is imported. Presumably it
  # pre-populates ToolManager's cache of hardcoded providers; the method is
  # defined outside this chunk, so confirm there.
  831. ToolManager.load_hardcoded_providers_cache()