You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tool_manager.py 41KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031
  1. import json
  2. import logging
  3. import mimetypes
  4. import time
  5. from collections.abc import Generator, Mapping
  6. from os import listdir, path
  7. from threading import Lock
  8. from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
  9. import sqlalchemy as sa
  10. from pydantic import TypeAdapter
  11. from yarl import URL
  12. import contexts
  13. from core.helper.provider_cache import ToolProviderCredentialsCache
  14. from core.plugin.entities.plugin import ToolProviderID
  15. from core.plugin.impl.oauth import OAuthHandler
  16. from core.plugin.impl.tool import PluginToolManager
  17. from core.tools.__base.tool_provider import ToolProviderController
  18. from core.tools.__base.tool_runtime import ToolRuntime
  19. from core.tools.mcp_tool.provider import MCPToolProviderController
  20. from core.tools.mcp_tool.tool import MCPTool
  21. from core.tools.plugin_tool.provider import PluginToolProviderController
  22. from core.tools.plugin_tool.tool import PluginTool
  23. from core.tools.utils.uuid_utils import is_valid_uuid
  24. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  25. from core.workflow.entities.variable_pool import VariablePool
  26. from services.tools.mcp_tools_manage_service import MCPToolManageService
  27. if TYPE_CHECKING:
  28. from core.workflow.nodes.tool.entities import ToolEntity
  29. from configs import dify_config
  30. from core.agent.entities import AgentToolEntity
  31. from core.app.entities.app_invoke_entities import InvokeFrom
  32. from core.helper.module_import_helper import load_single_subclass_from_source
  33. from core.helper.position_helper import is_filtered
  34. from core.model_runtime.utils.encoders import jsonable_encoder
  35. from core.tools.__base.tool import Tool
  36. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  37. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  38. from core.tools.builtin_tool.tool import BuiltinTool
  39. from core.tools.custom_tool.provider import ApiToolProviderController
  40. from core.tools.custom_tool.tool import ApiTool
  41. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  42. from core.tools.entities.common_entities import I18nObject
  43. from core.tools.entities.tool_entities import (
  44. ApiProviderAuthType,
  45. CredentialType,
  46. ToolInvokeFrom,
  47. ToolParameter,
  48. ToolProviderType,
  49. )
  50. from core.tools.errors import ToolProviderNotFoundError
  51. from core.tools.tool_label_manager import ToolLabelManager
  52. from core.tools.utils.configuration import (
  53. ToolParameterConfigurationManager,
  54. )
  55. from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
  56. from core.tools.workflow_as_tool.tool import WorkflowTool
  57. from extensions.ext_database import db
  58. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  59. from services.tools.tools_transform_service import ToolTransformService
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ToolManager:
    """
    Resolve tool providers (hardcoded builtin, plugin, API, workflow, MCP) and build
    tool runtimes for them.

    Hardcoded builtin providers are loaded lazily from ``builtin_tool/providers`` and
    cached on the class itself (shared by every caller in the process).
    """

    # Guards the lazy loading of the hardcoded builtin providers.
    _builtin_provider_lock = Lock()
    # Hardcoded builtin provider controllers, keyed by ``entity.identity.name``.
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
    # Becomes True once the providers directory has been fully walked.
    _builtin_providers_loaded = False
    # Labels of hardcoded builtin tools, keyed by tool ``entity.identity.name``.
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
  66. @classmethod
  67. def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
  68. """
  69. get the hardcoded provider
  70. """
  71. if len(cls._hardcoded_providers) == 0:
  72. # init the builtin providers
  73. cls.load_hardcoded_providers_cache()
  74. return cls._hardcoded_providers[provider]
  75. @classmethod
  76. def get_builtin_provider(
  77. cls, provider: str, tenant_id: str
  78. ) -> BuiltinToolProviderController | PluginToolProviderController:
  79. """
  80. get the builtin provider
  81. :param provider: the name of the provider
  82. :param tenant_id: the id of the tenant
  83. :return: the provider
  84. """
  85. # split provider to
  86. if len(cls._hardcoded_providers) == 0:
  87. # init the builtin providers
  88. cls.load_hardcoded_providers_cache()
  89. if provider not in cls._hardcoded_providers:
  90. # get plugin provider
  91. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  92. if plugin_provider:
  93. return plugin_provider
  94. return cls._hardcoded_providers[provider]
  95. @classmethod
  96. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  97. """
  98. get the plugin provider
  99. """
  100. # check if context is set
  101. try:
  102. contexts.plugin_tool_providers.get()
  103. except LookupError:
  104. contexts.plugin_tool_providers.set({})
  105. contexts.plugin_tool_providers_lock.set(Lock())
  106. plugin_tool_providers = contexts.plugin_tool_providers.get()
  107. if provider in plugin_tool_providers:
  108. return plugin_tool_providers[provider]
  109. with contexts.plugin_tool_providers_lock.get():
  110. # double check
  111. plugin_tool_providers = contexts.plugin_tool_providers.get()
  112. if provider in plugin_tool_providers:
  113. return plugin_tool_providers[provider]
  114. manager = PluginToolManager()
  115. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  116. if not provider_entity:
  117. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  118. controller = PluginToolProviderController(
  119. entity=provider_entity.declaration,
  120. plugin_id=provider_entity.plugin_id,
  121. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  122. tenant_id=tenant_id,
  123. )
  124. plugin_tool_providers[provider] = controller
  125. return controller
  126. @classmethod
  127. def get_tool_runtime(
  128. cls,
  129. provider_type: ToolProviderType,
  130. provider_id: str,
  131. tool_name: str,
  132. tenant_id: str,
  133. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  134. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  135. credential_id: Optional[str] = None,
  136. ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
  137. """
  138. get the tool runtime
  139. :param provider_type: the type of the provider
  140. :param provider_id: the id of the provider
  141. :param tool_name: the name of the tool
  142. :param tenant_id: the tenant id
  143. :param invoke_from: invoke from
  144. :param tool_invoke_from: the tool invoke from
  145. :param credential_id: the credential id
  146. :return: the tool
  147. """
  148. if provider_type == ToolProviderType.BUILT_IN:
  149. # check if the builtin tool need credentials
  150. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  151. builtin_tool = provider_controller.get_tool(tool_name)
  152. if not builtin_tool:
  153. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  154. if not provider_controller.need_credentials:
  155. return cast(
  156. BuiltinTool,
  157. builtin_tool.fork_tool_runtime(
  158. runtime=ToolRuntime(
  159. tenant_id=tenant_id,
  160. credentials={},
  161. invoke_from=invoke_from,
  162. tool_invoke_from=tool_invoke_from,
  163. )
  164. ),
  165. )
  166. builtin_provider = None
  167. if isinstance(provider_controller, PluginToolProviderController):
  168. provider_id_entity = ToolProviderID(provider_id)
  169. # get specific credentials
  170. if is_valid_uuid(credential_id):
  171. try:
  172. builtin_provider = (
  173. db.session.query(BuiltinToolProvider)
  174. .where(
  175. BuiltinToolProvider.tenant_id == tenant_id,
  176. BuiltinToolProvider.id == credential_id,
  177. )
  178. .first()
  179. )
  180. except Exception as e:
  181. builtin_provider = None
  182. logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
  183. # if the provider has been deleted, raise an error
  184. if builtin_provider is None:
  185. raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
  186. # fallback to the default provider
  187. if builtin_provider is None:
  188. # use the default provider
  189. builtin_provider = (
  190. db.session.query(BuiltinToolProvider)
  191. .where(
  192. BuiltinToolProvider.tenant_id == tenant_id,
  193. (BuiltinToolProvider.provider == str(provider_id_entity))
  194. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  195. )
  196. .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
  197. .first()
  198. )
  199. if builtin_provider is None:
  200. raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
  201. else:
  202. builtin_provider = (
  203. db.session.query(BuiltinToolProvider)
  204. .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  205. .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
  206. .first()
  207. )
  208. if builtin_provider is None:
  209. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  210. encrypter, cache = create_provider_encrypter(
  211. tenant_id=tenant_id,
  212. config=[
  213. x.to_basic_provider_config()
  214. for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
  215. ],
  216. cache=ToolProviderCredentialsCache(
  217. tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
  218. ),
  219. )
  220. # decrypt the credentials
  221. decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
  222. # check if the credentials is expired
  223. if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
  224. # TODO: circular import
  225. from services.tools.builtin_tools_manage_service import BuiltinToolManageService
  226. # refresh the credentials
  227. tool_provider = ToolProviderID(provider_id)
  228. provider_name = tool_provider.provider_name
  229. redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
  230. system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
  231. oauth_handler = OAuthHandler()
  232. # refresh the credentials
  233. refreshed_credentials = oauth_handler.refresh_credentials(
  234. tenant_id=tenant_id,
  235. user_id=builtin_provider.user_id,
  236. plugin_id=tool_provider.plugin_id,
  237. provider=provider_name,
  238. redirect_uri=redirect_uri,
  239. system_credentials=system_credentials or {},
  240. credentials=decrypted_credentials,
  241. )
  242. # update the credentials
  243. builtin_provider.encrypted_credentials = (
  244. TypeAdapter(dict[str, Any])
  245. .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
  246. .decode("utf-8")
  247. )
  248. builtin_provider.expires_at = refreshed_credentials.expires_at
  249. db.session.commit()
  250. decrypted_credentials = refreshed_credentials.credentials
  251. cache.delete()
  252. return cast(
  253. BuiltinTool,
  254. builtin_tool.fork_tool_runtime(
  255. runtime=ToolRuntime(
  256. tenant_id=tenant_id,
  257. credentials=dict(decrypted_credentials),
  258. credential_type=CredentialType.of(builtin_provider.credential_type),
  259. runtime_parameters={},
  260. invoke_from=invoke_from,
  261. tool_invoke_from=tool_invoke_from,
  262. )
  263. ),
  264. )
  265. elif provider_type == ToolProviderType.API:
  266. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  267. encrypter, _ = create_tool_provider_encrypter(
  268. tenant_id=tenant_id,
  269. controller=api_provider,
  270. )
  271. return cast(
  272. ApiTool,
  273. api_provider.get_tool(tool_name).fork_tool_runtime(
  274. runtime=ToolRuntime(
  275. tenant_id=tenant_id,
  276. credentials=encrypter.decrypt(credentials),
  277. invoke_from=invoke_from,
  278. tool_invoke_from=tool_invoke_from,
  279. )
  280. ),
  281. )
  282. elif provider_type == ToolProviderType.WORKFLOW:
  283. workflow_provider = (
  284. db.session.query(WorkflowToolProvider)
  285. .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  286. .first()
  287. )
  288. if workflow_provider is None:
  289. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  290. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  291. controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
  292. if controller_tools is None or len(controller_tools) == 0:
  293. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  294. return cast(
  295. WorkflowTool,
  296. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  297. runtime=ToolRuntime(
  298. tenant_id=tenant_id,
  299. credentials={},
  300. invoke_from=invoke_from,
  301. tool_invoke_from=tool_invoke_from,
  302. )
  303. ),
  304. )
  305. elif provider_type == ToolProviderType.APP:
  306. raise NotImplementedError("app provider not implemented")
  307. elif provider_type == ToolProviderType.PLUGIN:
  308. return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
  309. elif provider_type == ToolProviderType.MCP:
  310. return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
  311. else:
  312. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  313. @classmethod
  314. def get_agent_tool_runtime(
  315. cls,
  316. tenant_id: str,
  317. app_id: str,
  318. agent_tool: AgentToolEntity,
  319. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  320. variable_pool: Optional[VariablePool] = None,
  321. ) -> Tool:
  322. """
  323. get the agent tool runtime
  324. """
  325. tool_entity = cls.get_tool_runtime(
  326. provider_type=agent_tool.provider_type,
  327. provider_id=agent_tool.provider_id,
  328. tool_name=agent_tool.tool_name,
  329. tenant_id=tenant_id,
  330. invoke_from=invoke_from,
  331. tool_invoke_from=ToolInvokeFrom.AGENT,
  332. credential_id=agent_tool.credential_id,
  333. )
  334. runtime_parameters = {}
  335. parameters = tool_entity.get_merged_runtime_parameters()
  336. runtime_parameters = cls._convert_tool_parameters_type(
  337. parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
  338. )
  339. # decrypt runtime parameters
  340. encryption_manager = ToolParameterConfigurationManager(
  341. tenant_id=tenant_id,
  342. tool_runtime=tool_entity,
  343. provider_name=agent_tool.provider_id,
  344. provider_type=agent_tool.provider_type,
  345. identity_id=f"AGENT.{app_id}",
  346. )
  347. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  348. if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
  349. raise ValueError("runtime not found or runtime parameters not found")
  350. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  351. return tool_entity
  352. @classmethod
  353. def get_workflow_tool_runtime(
  354. cls,
  355. tenant_id: str,
  356. app_id: str,
  357. node_id: str,
  358. workflow_tool: "ToolEntity",
  359. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  360. variable_pool: Optional[VariablePool] = None,
  361. ) -> Tool:
  362. """
  363. get the workflow tool runtime
  364. """
  365. tool_runtime = cls.get_tool_runtime(
  366. provider_type=workflow_tool.provider_type,
  367. provider_id=workflow_tool.provider_id,
  368. tool_name=workflow_tool.tool_name,
  369. tenant_id=tenant_id,
  370. invoke_from=invoke_from,
  371. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  372. credential_id=workflow_tool.credential_id,
  373. )
  374. parameters = tool_runtime.get_merged_runtime_parameters()
  375. runtime_parameters = cls._convert_tool_parameters_type(
  376. parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
  377. )
  378. # decrypt runtime parameters
  379. encryption_manager = ToolParameterConfigurationManager(
  380. tenant_id=tenant_id,
  381. tool_runtime=tool_runtime,
  382. provider_name=workflow_tool.provider_id,
  383. provider_type=workflow_tool.provider_type,
  384. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  385. )
  386. if runtime_parameters:
  387. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  388. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  389. return tool_runtime
  390. @classmethod
  391. def get_tool_runtime_from_plugin(
  392. cls,
  393. tool_type: ToolProviderType,
  394. tenant_id: str,
  395. provider: str,
  396. tool_name: str,
  397. tool_parameters: dict[str, Any],
  398. credential_id: Optional[str] = None,
  399. ) -> Tool:
  400. """
  401. get tool runtime from plugin
  402. """
  403. tool_entity = cls.get_tool_runtime(
  404. provider_type=tool_type,
  405. provider_id=provider,
  406. tool_name=tool_name,
  407. tenant_id=tenant_id,
  408. invoke_from=InvokeFrom.SERVICE_API,
  409. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  410. credential_id=credential_id,
  411. )
  412. runtime_parameters = {}
  413. parameters = tool_entity.get_merged_runtime_parameters()
  414. for parameter in parameters:
  415. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  416. # save tool parameter to tool entity memory
  417. value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
  418. runtime_parameters[parameter.name] = value
  419. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  420. return tool_entity
  421. @classmethod
  422. def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
  423. """
  424. get the absolute path of the icon of the hardcoded provider
  425. :param provider: the name of the provider
  426. :return: the absolute path of the icon, the mime type of the icon
  427. """
  428. # get provider
  429. provider_controller = cls.get_hardcoded_provider(provider)
  430. absolute_path = path.join(
  431. path.dirname(path.realpath(__file__)),
  432. "builtin_tool",
  433. "providers",
  434. provider,
  435. "_assets",
  436. provider_controller.entity.identity.icon,
  437. )
  438. # check if the icon exists
  439. if not path.exists(absolute_path):
  440. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  441. # get the mime type
  442. mime_type, _ = mimetypes.guess_type(absolute_path)
  443. mime_type = mime_type or "application/octet-stream"
  444. return absolute_path, mime_type
  445. @classmethod
  446. def list_hardcoded_providers(cls):
  447. # use cache first
  448. if cls._builtin_providers_loaded:
  449. yield from list(cls._hardcoded_providers.values())
  450. return
  451. with cls._builtin_provider_lock:
  452. if cls._builtin_providers_loaded:
  453. yield from list(cls._hardcoded_providers.values())
  454. return
  455. yield from cls._list_hardcoded_providers()
  456. @classmethod
  457. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  458. """
  459. list all the plugin providers
  460. """
  461. manager = PluginToolManager()
  462. provider_entities = manager.fetch_tool_providers(tenant_id)
  463. return [
  464. PluginToolProviderController(
  465. entity=provider.declaration,
  466. plugin_id=provider.plugin_id,
  467. plugin_unique_identifier=provider.plugin_unique_identifier,
  468. tenant_id=tenant_id,
  469. )
  470. for provider in provider_entities
  471. ]
  472. @classmethod
  473. def list_builtin_providers(
  474. cls, tenant_id: str
  475. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  476. """
  477. list all the builtin providers
  478. """
  479. yield from cls.list_hardcoded_providers()
  480. # get plugin providers
  481. yield from cls.list_plugin_providers(tenant_id)
  482. @classmethod
  483. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  484. """
  485. list all the builtin providers
  486. """
  487. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  488. if provider_path.startswith("__"):
  489. continue
  490. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  491. if provider_path.startswith("__"):
  492. continue
  493. # init provider
  494. try:
  495. provider_class = load_single_subclass_from_source(
  496. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  497. script_path=path.join(
  498. path.dirname(path.realpath(__file__)),
  499. "builtin_tool",
  500. "providers",
  501. provider_path,
  502. f"{provider_path}.py",
  503. ),
  504. parent_type=BuiltinToolProviderController,
  505. )
  506. provider: BuiltinToolProviderController = provider_class()
  507. cls._hardcoded_providers[provider.entity.identity.name] = provider
  508. for tool in provider.get_tools():
  509. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  510. yield provider
  511. except Exception:
  512. logger.exception("load builtin provider %s", provider_path)
  513. continue
  514. # set builtin providers loaded
  515. cls._builtin_providers_loaded = True
  516. @classmethod
  517. def load_hardcoded_providers_cache(cls):
  518. for _ in cls.list_hardcoded_providers():
  519. pass
  520. @classmethod
  521. def clear_hardcoded_providers_cache(cls):
  522. cls._hardcoded_providers = {}
  523. cls._builtin_providers_loaded = False
  524. @classmethod
  525. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  526. """
  527. get the tool label
  528. :param tool_name: the name of the tool
  529. :return: the label of the tool
  530. """
  531. if len(cls._builtin_tools_labels) == 0:
  532. # init the builtin providers
  533. cls.load_hardcoded_providers_cache()
  534. if tool_name not in cls._builtin_tools_labels:
  535. return None
  536. return cls._builtin_tools_labels[tool_name]
  537. @classmethod
  538. def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
  539. """
  540. list all the builtin providers
  541. """
  542. # according to multi credentials, select the one with is_default=True first, then created_at oldest
  543. # for compatibility with old version
  544. sql = """
  545. SELECT DISTINCT ON (tenant_id, provider) id
  546. FROM tool_builtin_providers
  547. WHERE tenant_id = :tenant_id
  548. ORDER BY tenant_id, provider, is_default DESC, created_at DESC
  549. """
  550. ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
  551. return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
  552. @classmethod
  553. def list_providers_from_api(
  554. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  555. ) -> list[ToolProviderApiEntity]:
  556. result_providers: dict[str, ToolProviderApiEntity] = {}
  557. filters = []
  558. if not typ:
  559. filters.extend(["builtin", "api", "workflow", "mcp"])
  560. else:
  561. filters.append(typ)
  562. with db.session.no_autoflush:
  563. if "builtin" in filters:
  564. builtin_providers = cls.list_builtin_providers(tenant_id)
  565. # key: provider name, value: provider
  566. db_builtin_providers = {
  567. str(ToolProviderID(provider.provider)): provider
  568. for provider in cls.list_default_builtin_providers(tenant_id)
  569. }
  570. # append builtin providers
  571. for provider in builtin_providers:
  572. # handle include, exclude
  573. if is_filtered(
  574. include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
  575. exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
  576. data=provider,
  577. name_func=lambda x: x.identity.name,
  578. ):
  579. continue
  580. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  581. provider_controller=provider,
  582. db_provider=db_builtin_providers.get(provider.entity.identity.name),
  583. decrypt_credentials=False,
  584. )
  585. if isinstance(provider, PluginToolProviderController):
  586. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  587. else:
  588. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  589. # get db api providers
  590. if "api" in filters:
  591. db_api_providers: list[ApiToolProvider] = (
  592. db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
  593. )
  594. api_provider_controllers: list[dict[str, Any]] = [
  595. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  596. for provider in db_api_providers
  597. ]
  598. # get labels
  599. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  600. for api_provider_controller in api_provider_controllers:
  601. user_provider = ToolTransformService.api_provider_to_user_provider(
  602. provider_controller=api_provider_controller["controller"],
  603. db_provider=api_provider_controller["provider"],
  604. decrypt_credentials=False,
  605. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  606. )
  607. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  608. if "workflow" in filters:
  609. # get workflow providers
  610. workflow_providers: list[WorkflowToolProvider] = (
  611. db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
  612. )
  613. workflow_provider_controllers: list[WorkflowToolProviderController] = []
  614. for workflow_provider in workflow_providers:
  615. try:
  616. workflow_provider_controllers.append(
  617. ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  618. )
  619. except Exception:
  620. # app has been deleted
  621. pass
  622. labels = ToolLabelManager.get_tools_labels(
  623. [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
  624. )
  625. for provider_controller in workflow_provider_controllers:
  626. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  627. provider_controller=provider_controller,
  628. labels=labels.get(provider_controller.provider_id, []),
  629. )
  630. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  631. if "mcp" in filters:
  632. mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
  633. for mcp_provider in mcp_providers:
  634. result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
  635. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  636. @classmethod
  637. def get_api_provider_controller(
  638. cls, tenant_id: str, provider_id: str
  639. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  640. """
  641. get the api provider
  642. :param tenant_id: the id of the tenant
  643. :param provider_id: the id of the provider
  644. :return: the provider controller, the credentials
  645. """
  646. provider: ApiToolProvider | None = (
  647. db.session.query(ApiToolProvider)
  648. .where(
  649. ApiToolProvider.id == provider_id,
  650. ApiToolProvider.tenant_id == tenant_id,
  651. )
  652. .first()
  653. )
  654. if provider is None:
  655. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  656. auth_type = ApiProviderAuthType.NONE
  657. provider_auth_type = provider.credentials.get("auth_type")
  658. if provider_auth_type in ("api_key_header", "api_key"): # backward compatibility
  659. auth_type = ApiProviderAuthType.API_KEY_HEADER
  660. elif provider_auth_type == "api_key_query":
  661. auth_type = ApiProviderAuthType.API_KEY_QUERY
  662. controller = ApiToolProviderController.from_db(
  663. provider,
  664. auth_type,
  665. )
  666. controller.load_bundled_tools(provider.tools)
  667. return controller, provider.credentials
  668. @classmethod
  669. def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
  670. """
  671. get the api provider
  672. :param tenant_id: the id of the tenant
  673. :param provider_id: the id of the provider
  674. :return: the provider controller, the credentials
  675. """
  676. provider: MCPToolProvider | None = (
  677. db.session.query(MCPToolProvider)
  678. .where(
  679. MCPToolProvider.server_identifier == provider_id,
  680. MCPToolProvider.tenant_id == tenant_id,
  681. )
  682. .first()
  683. )
  684. if provider is None:
  685. raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
  686. controller = MCPToolProviderController._from_db(provider)
  687. return controller
  688. @classmethod
  689. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  690. """
  691. get api provider
  692. """
  693. """
  694. get tool provider
  695. """
  696. provider_name = provider
  697. provider_obj: ApiToolProvider | None = (
  698. db.session.query(ApiToolProvider)
  699. .where(
  700. ApiToolProvider.tenant_id == tenant_id,
  701. ApiToolProvider.name == provider,
  702. )
  703. .first()
  704. )
  705. if provider_obj is None:
  706. raise ValueError(f"you have not added provider {provider_name}")
  707. try:
  708. credentials = json.loads(provider_obj.credentials_str) or {}
  709. except Exception:
  710. credentials = {}
  711. # package tool provider controller
  712. auth_type = ApiProviderAuthType.NONE
  713. credentials_auth_type = credentials.get("auth_type")
  714. if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
  715. auth_type = ApiProviderAuthType.API_KEY_HEADER
  716. elif credentials_auth_type == "api_key_query":
  717. auth_type = ApiProviderAuthType.API_KEY_QUERY
  718. controller = ApiToolProviderController.from_db(
  719. provider_obj,
  720. auth_type,
  721. )
  722. # init tool configuration
  723. encrypter, _ = create_tool_provider_encrypter(
  724. tenant_id=tenant_id,
  725. controller=controller,
  726. )
  727. masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
  728. try:
  729. icon = json.loads(provider_obj.icon)
  730. except Exception:
  731. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  732. # add tool labels
  733. labels = ToolLabelManager.get_tool_labels(controller)
  734. return cast(
  735. dict,
  736. jsonable_encoder(
  737. {
  738. "schema_type": provider_obj.schema_type,
  739. "schema": provider_obj.schema,
  740. "tools": provider_obj.tools,
  741. "icon": icon,
  742. "description": provider_obj.description,
  743. "credentials": masked_credentials,
  744. "privacy_policy": provider_obj.privacy_policy,
  745. "custom_disclaimer": provider_obj.custom_disclaimer,
  746. "labels": labels,
  747. }
  748. ),
  749. )
  750. @classmethod
  751. def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
  752. return str(
  753. URL(dify_config.CONSOLE_API_URL or "/")
  754. / "console"
  755. / "api"
  756. / "workspaces"
  757. / "current"
  758. / "tool-provider"
  759. / "builtin"
  760. / provider_id
  761. / "icon"
  762. )
  763. @classmethod
  764. def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
  765. return str(
  766. URL(dify_config.CONSOLE_API_URL or "/")
  767. / "console"
  768. / "api"
  769. / "workspaces"
  770. / "current"
  771. / "plugin"
  772. / "icon"
  773. % {"tenant_id": tenant_id, "filename": filename}
  774. )
  775. @classmethod
  776. def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  777. try:
  778. workflow_provider: WorkflowToolProvider | None = (
  779. db.session.query(WorkflowToolProvider)
  780. .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  781. .first()
  782. )
  783. if workflow_provider is None:
  784. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  785. icon: dict = json.loads(workflow_provider.icon)
  786. return icon
  787. except Exception:
  788. return {"background": "#252525", "content": "\ud83d\ude01"}
  789. @classmethod
  790. def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  791. try:
  792. api_provider: ApiToolProvider | None = (
  793. db.session.query(ApiToolProvider)
  794. .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  795. .first()
  796. )
  797. if api_provider is None:
  798. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  799. icon: dict = json.loads(api_provider.icon)
  800. return icon
  801. except Exception:
  802. return {"background": "#252525", "content": "\ud83d\ude01"}
  803. @classmethod
  804. def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
  805. try:
  806. mcp_provider: MCPToolProvider | None = (
  807. db.session.query(MCPToolProvider)
  808. .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
  809. .first()
  810. )
  811. if mcp_provider is None:
  812. raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
  813. return mcp_provider.provider_icon
  814. except Exception:
  815. return {"background": "#252525", "content": "\ud83d\ude01"}
  816. @classmethod
  817. def get_tool_icon(
  818. cls,
  819. tenant_id: str,
  820. provider_type: ToolProviderType,
  821. provider_id: str,
  822. ) -> Union[str, dict]:
  823. """
  824. get the tool icon
  825. :param tenant_id: the id of the tenant
  826. :param provider_type: the type of the provider
  827. :param provider_id: the id of the provider
  828. :return:
  829. """
  830. provider_type = provider_type
  831. provider_id = provider_id
  832. if provider_type == ToolProviderType.BUILT_IN:
  833. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  834. if isinstance(provider, PluginToolProviderController):
  835. try:
  836. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  837. except Exception:
  838. return {"background": "#252525", "content": "\ud83d\ude01"}
  839. return cls.generate_builtin_tool_icon_url(provider_id)
  840. elif provider_type == ToolProviderType.API:
  841. return cls.generate_api_tool_icon_url(tenant_id, provider_id)
  842. elif provider_type == ToolProviderType.WORKFLOW:
  843. return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
  844. elif provider_type == ToolProviderType.PLUGIN:
  845. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  846. if isinstance(provider, PluginToolProviderController):
  847. try:
  848. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  849. except Exception:
  850. return {"background": "#252525", "content": "\ud83d\ude01"}
  851. raise ValueError(f"plugin provider {provider_id} not found")
  852. elif provider_type == ToolProviderType.MCP:
  853. return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
  854. else:
  855. raise ValueError(f"provider type {provider_type} not found")
  856. @classmethod
  857. def _convert_tool_parameters_type(
  858. cls,
  859. parameters: list[ToolParameter],
  860. variable_pool: Optional[VariablePool],
  861. tool_configurations: dict[str, Any],
  862. typ: Literal["agent", "workflow", "tool"] = "workflow",
  863. ) -> dict[str, Any]:
  864. """
  865. Convert tool parameters type
  866. """
  867. from core.workflow.nodes.tool.entities import ToolNodeData
  868. from core.workflow.nodes.tool.exc import ToolParameterError
  869. runtime_parameters = {}
  870. for parameter in parameters:
  871. if (
  872. parameter.type
  873. in {
  874. ToolParameter.ToolParameterType.SYSTEM_FILES,
  875. ToolParameter.ToolParameterType.FILE,
  876. ToolParameter.ToolParameterType.FILES,
  877. }
  878. and parameter.required
  879. and typ == "agent"
  880. ):
  881. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  882. # save tool parameter to tool entity memory
  883. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  884. if variable_pool:
  885. config = tool_configurations.get(parameter.name, {})
  886. if not (config and isinstance(config, dict) and config.get("value") is not None):
  887. continue
  888. tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
  889. if tool_input.type == "variable":
  890. variable = variable_pool.get(tool_input.value)
  891. if variable is None:
  892. raise ToolParameterError(f"Variable {tool_input.value} does not exist")
  893. parameter_value = variable.value
  894. elif tool_input.type == "constant":
  895. parameter_value = tool_input.value
  896. elif tool_input.type == "mixed":
  897. segment_group = variable_pool.convert_template(str(tool_input.value))
  898. parameter_value = segment_group.text
  899. else:
  900. raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
  901. runtime_parameters[parameter.name] = parameter_value
  902. else:
  903. value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
  904. runtime_parameters[parameter.name] = value
  905. return runtime_parameters
# Module-level call, outside the class body since it names ToolManager
# directly: it runs once when this module is first imported and populates
# ToolManager's hard-coded provider cache.
# NOTE(review): presumably a warm-up so later lookups do not load these
# providers lazily; confirm against load_hardcoded_providers_cache.
ToolManager.load_hardcoded_providers_cache()