Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

tool_manager.py 39KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
  8. from yarl import URL
  9. import contexts
  10. from core.helper.provider_cache import ToolProviderCredentialsCache
  11. from core.plugin.entities.plugin import ToolProviderID
  12. from core.plugin.impl.tool import PluginToolManager
  13. from core.tools.__base.tool_provider import ToolProviderController
  14. from core.tools.__base.tool_runtime import ToolRuntime
  15. from core.tools.mcp_tool.provider import MCPToolProviderController
  16. from core.tools.mcp_tool.tool import MCPTool
  17. from core.tools.plugin_tool.provider import PluginToolProviderController
  18. from core.tools.plugin_tool.tool import PluginTool
  19. from core.tools.utils.uuid_utils import is_valid_uuid
  20. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  21. from core.workflow.entities.variable_pool import VariablePool
  22. from services.tools.mcp_tools_mange_service import MCPToolManageService
  23. if TYPE_CHECKING:
  24. from core.workflow.nodes.tool.entities import ToolEntity
  25. from configs import dify_config
  26. from core.agent.entities import AgentToolEntity
  27. from core.app.entities.app_invoke_entities import InvokeFrom
  28. from core.helper.module_import_helper import load_single_subclass_from_source
  29. from core.helper.position_helper import is_filtered
  30. from core.model_runtime.utils.encoders import jsonable_encoder
  31. from core.tools.__base.tool import Tool
  32. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  33. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  34. from core.tools.builtin_tool.tool import BuiltinTool
  35. from core.tools.custom_tool.provider import ApiToolProviderController
  36. from core.tools.custom_tool.tool import ApiTool
  37. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  38. from core.tools.entities.common_entities import I18nObject
  39. from core.tools.entities.tool_entities import (
  40. ApiProviderAuthType,
  41. CredentialType,
  42. ToolInvokeFrom,
  43. ToolParameter,
  44. ToolProviderType,
  45. )
  46. from core.tools.errors import ToolProviderNotFoundError
  47. from core.tools.tool_label_manager import ToolLabelManager
  48. from core.tools.utils.configuration import (
  49. ToolParameterConfigurationManager,
  50. )
  51. from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
  52. from core.tools.workflow_as_tool.tool import WorkflowTool
  53. from extensions.ext_database import db
  54. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  55. from services.tools.tools_transform_service import ToolTransformService
  56. logger = logging.getLogger(__name__)
  57. class ToolManager:
  # Lock taken while the hardcoded builtin providers are being loaded.
  _builtin_provider_lock = Lock()
  # Hardcoded builtin provider controllers, keyed by provider identity name.
  _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
  # Flipped to True once the hardcoded provider loader has run to completion.
  _builtin_providers_loaded = False
  # Labels of hardcoded builtin tools, keyed by tool identity name.
  _builtin_tools_labels: dict[str, Optional[I18nObject]] = {}
  @classmethod
  def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
      """
      Return the hardcoded builtin provider controller registered under ``provider``.

      The hardcoded provider cache is populated lazily on first access.

      :param provider: the provider identity name
      :return: the provider controller
      :raises KeyError: if no hardcoded provider with that name was loaded
      """
      cache_is_empty = not cls._hardcoded_providers
      if cache_is_empty:
          # populate the cache of builtin providers on first access
          cls.load_hardcoded_providers_cache()
      return cls._hardcoded_providers[provider]
  @classmethod
  def get_builtin_provider(
      cls, provider: str, tenant_id: str
  ) -> BuiltinToolProviderController | PluginToolProviderController:
      """
      Resolve a builtin provider, preferring hardcoded providers over plugin providers.

      :param provider: the name of the provider
      :param tenant_id: the id of the tenant (only used for the plugin lookup)
      :return: the hardcoded controller registered under ``provider`` if any,
          otherwise the tenant's plugin provider controller
      :raises ToolProviderNotFoundError: propagated from ``get_plugin_provider``
          when the plugin provider cannot be found
      """
      if not cls._hardcoded_providers:
          # lazily load the hardcoded providers on first access
          cls.load_hardcoded_providers_cache()

      if provider in cls._hardcoded_providers:
          return cls._hardcoded_providers[provider]

      # not a hardcoded provider: fall back to the tenant's plugin providers
      plugin_provider = cls.get_plugin_provider(provider, tenant_id)
      if plugin_provider:
          return plugin_provider
      return cls._hardcoded_providers[provider]
  @classmethod
  def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
      """
      Return the plugin tool provider controller for ``provider``, memoised in ``contexts``.

      Controllers are cached in ``contexts.plugin_tool_providers``. A lock stored in
      ``contexts.plugin_tool_providers_lock`` guards the fetch, and the cache is re-checked
      after acquiring it.

      :param provider: the plugin provider id
      :param tenant_id: the id of the tenant
      :return: the plugin provider controller
      :raises ToolProviderNotFoundError: if the plugin manager returns no provider
      """
      # initialise the cache (and its lock) the first time it is used in this context
      try:
          contexts.plugin_tool_providers.get()
      except LookupError:
          contexts.plugin_tool_providers.set({})
          contexts.plugin_tool_providers_lock.set(Lock())

      cached = contexts.plugin_tool_providers.get()
      if provider in cached:
          return cached[provider]

      with contexts.plugin_tool_providers_lock.get():
          # double check: another caller may have filled the entry while we waited
          cached = contexts.plugin_tool_providers.get()
          if provider in cached:
              return cached[provider]

          entity = PluginToolManager().fetch_tool_provider(tenant_id, provider)
          if not entity:
              raise ToolProviderNotFoundError(f"plugin provider {provider} not found")

          controller = PluginToolProviderController(
              entity=entity.declaration,
              plugin_id=entity.plugin_id,
              plugin_unique_identifier=entity.plugin_unique_identifier,
              tenant_id=tenant_id,
          )
          cached[provider] = controller
          return controller
  @classmethod
  def get_tool_runtime(
      cls,
      provider_type: ToolProviderType,
      provider_id: str,
      tool_name: str,
      tenant_id: str,
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
      credential_id: Optional[str] = None,
  ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
      """
      Build a tool instance forked with a runtime (tenant, credentials, invocation source).

      :param provider_type: the type of the provider
      :param provider_id: the id of the provider
      :param tool_name: the name of the tool
      :param tenant_id: the tenant id
      :param invoke_from: invoke from
      :param tool_invoke_from: the tool invoke from
      :param credential_id: id of a specific ``BuiltinToolProvider`` credential record; only
          consulted for builtin tools served by a plugin provider
      :return: the tool, forked with its runtime
      :raises ToolProviderNotFoundError: if the provider, the tool, or a required
          credential record cannot be found
      :raises NotImplementedError: for ``ToolProviderType.APP``
      """
      if provider_type == ToolProviderType.BUILT_IN:
          # check if the builtin tool needs credentials
          provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
          builtin_tool = provider_controller.get_tool(tool_name)
          if not builtin_tool:
              raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

          if not provider_controller.need_credentials:
              # providers without credentials get an empty credential set
              return cast(
                  BuiltinTool,
                  builtin_tool.fork_tool_runtime(
                      runtime=ToolRuntime(
                          tenant_id=tenant_id,
                          credentials={},
                          invoke_from=invoke_from,
                          tool_invoke_from=tool_invoke_from,
                      )
                  ),
              )

          builtin_provider = None
          if isinstance(provider_controller, PluginToolProviderController):
              provider_id_entity = ToolProviderID(provider_id)
              # an explicitly requested credential must exist; it never falls back
              if is_valid_uuid(credential_id):
                  try:
                      builtin_provider = (
                          db.session.query(BuiltinToolProvider)
                          .filter(
                              BuiltinToolProvider.tenant_id == tenant_id,
                              BuiltinToolProvider.id == credential_id,
                          )
                          .first()
                      )
                  except Exception as e:
                      builtin_provider = None
                      logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
                  # if the provider has been deleted, raise an error
                  if builtin_provider is None:
                      raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")

              # no usable credential id: fall back to the default credential record
              # (is_default first, then the oldest), matching either the full plugin
              # provider id or the bare provider name
              if builtin_provider is None:
                  builtin_provider = (
                      db.session.query(BuiltinToolProvider)
                      .filter(
                          BuiltinToolProvider.tenant_id == tenant_id,
                          (BuiltinToolProvider.provider == str(provider_id_entity))
                          | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                      )
                      .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                      .first()
                  )
                  if builtin_provider is None:
                      raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
          else:
              # hardcoded provider: use its default credential record
              # (credential_id is not consulted on this path)
              builtin_provider = (
                  db.session.query(BuiltinToolProvider)
                  .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                  .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                  .first()
              )

          if builtin_provider is None:
              raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

          encrypter, _ = create_provider_encrypter(
              tenant_id=tenant_id,
              config=[
                  x.to_basic_provider_config()
                  for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
              ],
              cache=ToolProviderCredentialsCache(
                  tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
              ),
          )
          return cast(
              BuiltinTool,
              builtin_tool.fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials=encrypter.decrypt(builtin_provider.credentials),
                      credential_type=CredentialType.of(builtin_provider.credential_type),
                      runtime_parameters={},
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )

      elif provider_type == ToolProviderType.API:
          api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
          encrypter, _ = create_tool_provider_encrypter(
              tenant_id=tenant_id,
              controller=api_provider,
          )
          return cast(
              ApiTool,
              api_provider.get_tool(tool_name).fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials=encrypter.decrypt(credentials),
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )

      elif provider_type == ToolProviderType.WORKFLOW:
          workflow_provider = (
              db.session.query(WorkflowToolProvider)
              .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
              .first()
          )
          if workflow_provider is None:
              raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

          controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
          controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
          if not controller_tools:
              raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

          # reuse the tools fetched above instead of calling get_tools() a second time
          return cast(
              WorkflowTool,
              controller_tools[0].fork_tool_runtime(
                  runtime=ToolRuntime(
                      tenant_id=tenant_id,
                      credentials={},
                      invoke_from=invoke_from,
                      tool_invoke_from=tool_invoke_from,
                  )
              ),
          )

      elif provider_type == ToolProviderType.APP:
          raise NotImplementedError("app provider not implemented")
      elif provider_type == ToolProviderType.PLUGIN:
          return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
      elif provider_type == ToolProviderType.MCP:
          return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
      else:
          raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  @classmethod
  def get_agent_tool_runtime(
      cls,
      tenant_id: str,
      app_id: str,
      agent_tool: AgentToolEntity,
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      variable_pool: Optional[VariablePool] = None,
  ) -> Tool:
      """
      Build a tool runtime for an agent app and load its decrypted runtime parameters.

      :param tenant_id: the id of the tenant
      :param app_id: the id of the agent app (used in the parameter identity ``AGENT.<app_id>``)
      :param agent_tool: the agent's tool configuration
      :param invoke_from: invoke from
      :param variable_pool: optional variable pool passed to parameter conversion
      :return: the tool with ``runtime.runtime_parameters`` updated
      :raises ValueError: if the tool has no runtime or no runtime parameters dict
      """
      tool_entity = cls.get_tool_runtime(
          provider_type=agent_tool.provider_type,
          provider_id=agent_tool.provider_id,
          tool_name=agent_tool.tool_name,
          tenant_id=tenant_id,
          invoke_from=invoke_from,
          tool_invoke_from=ToolInvokeFrom.AGENT,
          credential_id=agent_tool.credential_id,
      )
      parameters = tool_entity.get_merged_runtime_parameters()
      runtime_parameters = cls._convert_tool_parameters_type(
          parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
      )

      # decrypt runtime parameters
      encryption_manager = ToolParameterConfigurationManager(
          tenant_id=tenant_id,
          tool_runtime=tool_entity,
          provider_name=agent_tool.provider_id,
          provider_type=agent_tool.provider_type,
          identity_id=f"AGENT.{app_id}",
      )
      runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

      if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")
      tool_entity.runtime.runtime_parameters.update(runtime_parameters)
      return tool_entity
  @classmethod
  def get_workflow_tool_runtime(
      cls,
      tenant_id: str,
      app_id: str,
      node_id: str,
      workflow_tool: "ToolEntity",
      invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      variable_pool: Optional[VariablePool] = None,
  ) -> Tool:
      """
      Build a tool runtime for a workflow tool node and load its decrypted runtime parameters.

      :param tenant_id: the id of the tenant
      :param app_id: the id of the workflow app
      :param node_id: the id of the tool node (used in the parameter identity
          ``WORKFLOW.<app_id>.<node_id>``)
      :param workflow_tool: the node's tool configuration
      :param invoke_from: invoke from
      :param variable_pool: optional variable pool passed to parameter conversion
      :return: the tool with ``runtime.runtime_parameters`` updated
      :raises ValueError: if the tool has no runtime or no runtime parameters dict
      """
      tool_runtime = cls.get_tool_runtime(
          provider_type=workflow_tool.provider_type,
          provider_id=workflow_tool.provider_id,
          tool_name=workflow_tool.tool_name,
          tenant_id=tenant_id,
          invoke_from=invoke_from,
          tool_invoke_from=ToolInvokeFrom.WORKFLOW,
          credential_id=workflow_tool.credential_id,
      )
      parameters = tool_runtime.get_merged_runtime_parameters()
      runtime_parameters = cls._convert_tool_parameters_type(
          parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
      )

      # decrypt runtime parameters
      encryption_manager = ToolParameterConfigurationManager(
          tenant_id=tenant_id,
          tool_runtime=tool_runtime,
          provider_name=workflow_tool.provider_id,
          provider_type=workflow_tool.provider_type,
          identity_id=f"WORKFLOW.{app_id}.{node_id}",
      )
      if runtime_parameters:
          runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

      # same guard as get_agent_tool_runtime: fail with a clear error instead of an
      # AttributeError when the tool carries no runtime
      if tool_runtime.runtime is None or tool_runtime.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")
      tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
      return tool_runtime
  @classmethod
  def get_tool_runtime_from_plugin(
      cls,
      tool_type: ToolProviderType,
      tenant_id: str,
      provider: str,
      tool_name: str,
      tool_parameters: dict[str, Any],
      credential_id: Optional[str] = None,
  ) -> Tool:
      """
      Build a tool runtime for a tool invoked from a plugin.

      Only parameters whose form is ``FORM`` are taken from ``tool_parameters``. Each is
      converted with ``init_frontend_parameter`` and stored on the runtime.

      :param tool_type: the provider type
      :param tenant_id: the id of the tenant
      :param provider: the provider id
      :param tool_name: the name of the tool
      :param tool_parameters: raw parameter values, keyed by parameter name
      :param credential_id: optional credential record id forwarded to ``get_tool_runtime``
      :return: the tool with ``runtime.runtime_parameters`` updated
      :raises ValueError: if the tool has no runtime or no runtime parameters dict
      """
      tool_entity = cls.get_tool_runtime(
          provider_type=tool_type,
          provider_id=provider,
          tool_name=tool_name,
          tenant_id=tenant_id,
          invoke_from=InvokeFrom.SERVICE_API,
          tool_invoke_from=ToolInvokeFrom.PLUGIN,
          credential_id=credential_id,
      )

      # save FORM tool parameters to the tool entity memory
      runtime_parameters = {
          parameter.name: parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
          for parameter in tool_entity.get_merged_runtime_parameters()
          if parameter.form == ToolParameter.ToolParameterForm.FORM
      }

      # consistent with get_agent_tool_runtime: clear error instead of AttributeError
      if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
          raise ValueError("runtime not found or runtime parameters not found")
      tool_entity.runtime.runtime_parameters.update(runtime_parameters)
      return tool_entity
  @classmethod
  def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
      """
      Locate the icon file bundled with a hardcoded builtin provider.

      :param provider: the name of the provider
      :return: the absolute path of the icon and its guessed MIME type
          (``application/octet-stream`` when it cannot be guessed)
      :raises ToolProviderNotFoundError: if the icon file does not exist
      """
      icon_name = cls.get_hardcoded_provider(provider).entity.identity.icon
      tools_dir = path.dirname(path.realpath(__file__))
      icon_path = path.join(tools_dir, "builtin_tool", "providers", provider, "_assets", icon_name)

      if not path.exists(icon_path):
          raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")

      guessed_type = mimetypes.guess_type(icon_path)[0]
      return icon_path, (guessed_type if guessed_type else "application/octet-stream")
  @classmethod
  def list_hardcoded_providers(cls):
      """
      Yield every hardcoded builtin provider controller, loading them on first use.

      Loading runs to completion while ``_builtin_provider_lock`` is held. The providers
      are yielded only after the lock has been released. This means:

      - the caller's per-provider work never runs under the (non-reentrant) lock;
      - a consumer that stops iterating early cannot leave ``_hardcoded_providers``
        half-populated with ``_builtin_providers_loaded`` still False.
      """
      # fast path: cache already populated
      if cls._builtin_providers_loaded:
          yield from list(cls._hardcoded_providers.values())
          return

      with cls._builtin_provider_lock:
          if cls._builtin_providers_loaded:
              # another thread finished loading while we waited for the lock
              providers = list(cls._hardcoded_providers.values())
          else:
              # drain the loader completely so it reaches the point where it sets
              # `_builtin_providers_loaded`
              providers = list(cls._list_hardcoded_providers())

      yield from providers
  @classmethod
  def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
      """
      Fetch the tenant's plugin tool providers and wrap each in a controller.

      :param tenant_id: the id of the tenant
      :return: one ``PluginToolProviderController`` per provider returned by the plugin manager
      """
      controllers: list[PluginToolProviderController] = []
      for entity in PluginToolManager().fetch_tool_providers(tenant_id):
          controllers.append(
              PluginToolProviderController(
                  entity=entity.declaration,
                  plugin_id=entity.plugin_id,
                  plugin_unique_identifier=entity.plugin_unique_identifier,
                  tenant_id=tenant_id,
              )
          )
      return controllers
  @classmethod
  def list_builtin_providers(
      cls, tenant_id: str
  ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
      """
      Yield all builtin providers: hardcoded ones first, then the tenant's plugin providers.

      :param tenant_id: the id of the tenant (only used for the plugin lookup)
      """
      hardcoded = cls.list_hardcoded_providers()
      yield from hardcoded
      # plugin providers are fetched only once the hardcoded ones have been consumed
      plugins = cls.list_plugin_providers(tenant_id)
      yield from plugins
  @classmethod
  def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
      """
      Discover, instantiate and cache the builtin providers under ``builtin_tool/providers``.

      Every sub-directory ``<name>`` is loaded, except names starting with ``__``. The loader
      looks for a ``BuiltinToolProviderController`` subclass in ``<name>/<name>.py``.

      For each provider that loads successfully:

      - it is stored in ``_hardcoded_providers`` under its identity name;
      - its tools' labels are recorded in ``_builtin_tools_labels``;
      - the provider is yielded.

      Providers that fail to load are logged and skipped. ``_builtin_providers_loaded`` is
      set only when the generator runs to completion.
      """
      providers_dir = path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")
      for provider_path in listdir(providers_dir):
          # skip entries whose names start with "__"
          if provider_path.startswith("__"):
              continue
          provider_dir = path.join(providers_dir, provider_path)
          if not path.isdir(provider_dir):
              continue

          # init provider
          try:
              provider_class = load_single_subclass_from_source(
                  module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
                  script_path=path.join(provider_dir, f"{provider_path}.py"),
                  parent_type=BuiltinToolProviderController,
              )
              provider: BuiltinToolProviderController = provider_class()
              cls._hardcoded_providers[provider.entity.identity.name] = provider
              for tool in provider.get_tools():
                  cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
              yield provider
          except Exception:
              logger.exception("load builtin provider %s", provider_path)

      # set builtin providers loaded
      cls._builtin_providers_loaded = True
  @classmethod
  def load_hardcoded_providers_cache(cls):
      """Populate the hardcoded provider cache by fully consuming ``list_hardcoded_providers``."""
      # loading is a side effect of iterating the generator; the yielded providers are discarded
      _ = list(cls.list_hardcoded_providers())
  @classmethod
  def clear_hardcoded_providers_cache(cls):
      """
      Drop the cached hardcoded providers so the next access reloads them.

      The tool-label cache is reset as well. ``get_tool_label`` only reloads providers
      when that cache is empty, so keeping it would serve labels from the previous load.
      """
      cls._hardcoded_providers = {}
      cls._builtin_tools_labels = {}
      cls._builtin_providers_loaded = False
  @classmethod
  def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
      """
      Look up the label of a hardcoded builtin tool.

      :param tool_name: the name of the tool
      :return: the tool's label, or ``None`` if no hardcoded tool has that name
      """
      if not cls._builtin_tools_labels:
          # labels are collected while the hardcoded providers are loaded
          cls.load_hardcoded_providers_cache()
      return cls._builtin_tools_labels.get(tool_name)
  @classmethod
  def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
      """
      Return one credential record per builtin provider for a tenant: the default one.

      With multiple credentials per provider, the record with ``is_default`` set wins.
      Ties are broken by the oldest ``created_at``, for compatibility with older versions.
      This is the same ordering ``get_tool_runtime`` uses when it falls back to a
      default credential, so listings and invocations agree.

      :param tenant_id: the id of the tenant
      :return: the selected ``BuiltinToolProvider`` rows (no particular order)
      """
      # NOTE: DISTINCT ON is PostgreSQL-specific syntax.
      # The previous "created_at DESC" picked the newest record, contradicting the
      # "oldest" rule above.
      sql = """
          SELECT DISTINCT ON (tenant_id, provider) id
          FROM tool_builtin_providers
          WHERE tenant_id = :tenant_id
          ORDER BY tenant_id, provider, is_default DESC, created_at ASC
      """
      ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
      if not ids:
          # nothing to fetch; skip the second query with an empty IN clause
          return []
      return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
  @classmethod
  def list_providers_from_api(
      cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  ) -> list[ToolProviderApiEntity]:
      """
      List the tool providers visible to a tenant, converted to API entities.

      :param user_id: the id of the requesting user (not used by this method)
      :param tenant_id: the id of the tenant
      :param typ: restrict the listing to one provider kind; when falsy, ``builtin``,
          ``api``, ``workflow`` and ``mcp`` providers are all listed
      :return: the providers, ordered by ``BuiltinToolProviderSort.sort``
      """
      result_providers: dict[str, ToolProviderApiEntity] = {}

      filters = []
      if not typ:
          filters.extend(["builtin", "api", "workflow", "mcp"])
      else:
          filters.append(typ)

      with db.session.no_autoflush:
          if "builtin" in filters:
              builtin_providers = cls.list_builtin_providers(tenant_id)

              # key: provider id string, value: the tenant's default credential record
              db_builtin_providers = {
                  str(ToolProviderID(provider.provider)): provider
                  for provider in cls.list_default_builtin_providers(tenant_id)
              }

              # the include/exclude sets are loop-invariant: read them once
              include_set = cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET)
              exclude_set = cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET)

              # append builtin providers
              for provider in builtin_providers:
                  # handle include, exclude
                  if is_filtered(
                      include_set=include_set,
                      exclude_set=exclude_set,
                      data=provider,
                      name_func=lambda x: x.identity.name,
                  ):
                      continue

                  user_provider = ToolTransformService.builtin_provider_to_user_provider(
                      provider_controller=provider,
                      db_provider=db_builtin_providers.get(provider.entity.identity.name),
                      decrypt_credentials=False,
                  )

                  if isinstance(provider, PluginToolProviderController):
                      result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
                  else:
                      result_providers[f"builtin_provider.{user_provider.name}"] = user_provider

          # get db api providers
          if "api" in filters:
              db_api_providers: list[ApiToolProvider] = (
                  db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
              )

              api_provider_controllers: list[dict[str, Any]] = [
                  {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
                  for provider in db_api_providers
              ]

              # get labels
              labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])

              for api_provider_controller in api_provider_controllers:
                  user_provider = ToolTransformService.api_provider_to_user_provider(
                      provider_controller=api_provider_controller["controller"],
                      db_provider=api_provider_controller["provider"],
                      decrypt_credentials=False,
                      labels=labels.get(api_provider_controller["controller"].provider_id, []),
                  )
                  result_providers[f"api_provider.{user_provider.name}"] = user_provider

          if "workflow" in filters:
              # get workflow providers
              workflow_providers: list[WorkflowToolProvider] = (
                  db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
              )

              workflow_provider_controllers: list[WorkflowToolProviderController] = []
              for workflow_provider in workflow_providers:
                  try:
                      workflow_provider_controllers.append(
                          ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
                      )
                  except Exception:
                      # best effort (presumably the backing app has been deleted): skip the
                      # provider, but leave a trace instead of failing silently
                      logger.warning(
                          "skip workflow tool provider %s: cannot build its controller",
                          workflow_provider.id,
                          exc_info=True,
                      )

              labels = ToolLabelManager.get_tools_labels(
                  [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
              )

              for provider_controller in workflow_provider_controllers:
                  user_provider = ToolTransformService.workflow_provider_to_user_provider(
                      provider_controller=provider_controller,
                      labels=labels.get(provider_controller.provider_id, []),
                  )
                  result_providers[f"workflow_provider.{user_provider.name}"] = user_provider

          if "mcp" in filters:
              mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
              for mcp_provider in mcp_providers:
                  result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider

      return BuiltinToolProviderSort.sort(list(result_providers.values()))
  600. @classmethod
  601. def get_api_provider_controller(
  602. cls, tenant_id: str, provider_id: str
  603. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  604. """
  605. get the api provider
  606. :param tenant_id: the id of the tenant
  607. :param provider_id: the id of the provider
  608. :return: the provider controller, the credentials
  609. """
  610. provider: ApiToolProvider | None = (
  611. db.session.query(ApiToolProvider)
  612. .filter(
  613. ApiToolProvider.id == provider_id,
  614. ApiToolProvider.tenant_id == tenant_id,
  615. )
  616. .first()
  617. )
  618. if provider is None:
  619. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  620. auth_type = ApiProviderAuthType.NONE
  621. provider_auth_type = provider.credentials.get("auth_type")
  622. if provider_auth_type in ("api_key_header", "api_key"): # backward compatibility
  623. auth_type = ApiProviderAuthType.API_KEY_HEADER
  624. elif provider_auth_type == "api_key_query":
  625. auth_type = ApiProviderAuthType.API_KEY_QUERY
  626. controller = ApiToolProviderController.from_db(
  627. provider,
  628. auth_type,
  629. )
  630. controller.load_bundled_tools(provider.tools)
  631. return controller, provider.credentials
  @classmethod
  def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
      """
      Build the controller for a tenant's MCP tool provider.

      :param tenant_id: the id of the tenant
      :param provider_id: the provider's ``server_identifier`` (matched against that column)
      :return: the MCP provider controller
      :raises ToolProviderNotFoundError: if no matching MCP provider exists for the tenant
      """
      query = db.session.query(MCPToolProvider).filter(
          MCPToolProvider.server_identifier == provider_id,
          MCPToolProvider.tenant_id == tenant_id,
      )
      provider: MCPToolProvider | None = query.first()
      if provider is None:
          raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
      return MCPToolProviderController._from_db(provider)
  @classmethod
  def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
      """
      Return a tenant's custom (API) tool provider, looked up by name, for display.

      Stored credentials are decrypted and then masked before being returned.

      :param provider: the provider name
      :param tenant_id: the id of the tenant
      :return: a JSON-encodable dict containing the schema type, schema, tools, icon,
          description, masked credentials, privacy policy, custom disclaimer and labels
      :raises ValueError: if the tenant has not added a provider with that name
      """
      provider_obj: ApiToolProvider | None = (
          db.session.query(ApiToolProvider)
          .filter(
              ApiToolProvider.tenant_id == tenant_id,
              ApiToolProvider.name == provider,
          )
          .first()
      )
      if provider_obj is None:
          raise ValueError(f"you have not added provider {provider}")

      # missing or undecodable credentials are treated as empty
      try:
          credentials = json.loads(provider_obj.credentials_str) or {}
      except (TypeError, ValueError):
          credentials = {}
      if not isinstance(credentials, dict):
          # valid JSON that is not an object cannot hold credentials
          credentials = {}

      # package tool provider controller; "api_key" is accepted for backward compatibility
      auth_type = ApiProviderAuthType.NONE
      credentials_auth_type = credentials.get("auth_type")
      if credentials_auth_type in ("api_key_header", "api_key"):
          auth_type = ApiProviderAuthType.API_KEY_HEADER
      elif credentials_auth_type == "api_key_query":
          auth_type = ApiProviderAuthType.API_KEY_QUERY

      controller = ApiToolProviderController.from_db(
          provider_obj,
          auth_type,
      )
      # init tool configuration
      encrypter, _ = create_tool_provider_encrypter(
          tenant_id=tenant_id,
          controller=controller,
      )

      masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))

      try:
          icon = json.loads(provider_obj.icon)
      except (TypeError, ValueError):
          # U+1F601 written as one code point: the escape pair \ud83d\ude01 yields two
          # lone surrogates in a Python str, which cannot be encoded as UTF-8
          icon = {"background": "#252525", "content": "\U0001f601"}

      # add tool labels
      labels = ToolLabelManager.get_tool_labels(controller)

      return cast(
          dict,
          jsonable_encoder(
              {
                  "schema_type": provider_obj.schema_type,
                  "schema": provider_obj.schema,
                  "tools": provider_obj.tools,
                  "icon": icon,
                  "description": provider_obj.description,
                  "credentials": masked_credentials,
                  "privacy_policy": provider_obj.privacy_policy,
                  "custom_disclaimer": provider_obj.custom_disclaimer,
                  "labels": labels,
              }
          ),
      )
  @classmethod
  def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
      """
      Build the console API URL that serves the icon of a built-in tool provider.

      :param provider_id: the id of the built-in tool provider
      :return: the icon URL rendered as a string
      """
      # A root-relative base is used when no console API URL is configured.
      icon_url = URL(dify_config.CONSOLE_API_URL or "/")
      path_segments = ("console", "api", "workspaces", "current", "tool-provider", "builtin", provider_id, "icon")
      for segment in path_segments:
          icon_url = icon_url / segment
      return str(icon_url)
  @classmethod
  def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
      """
      Build the console API URL that serves a plugin tool icon file.

      :param tenant_id: the id of the tenant that installed the plugin
      :param filename: the icon file name declared by the plugin
      :return: the icon URL, with tenant id and file name carried as query parameters
      """
      # A root-relative base is used when no console API URL is configured.
      icon_url = URL(dify_config.CONSOLE_API_URL or "/")
      for segment in ("console", "api", "workspaces", "current", "plugin", "icon"):
          icon_url = icon_url / segment
      # `%` on a yarl URL updates its query part with this mapping.
      query_params = {"tenant_id": tenant_id, "filename": filename}
      return str(icon_url % query_params)
  @classmethod
  def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
      """
      Return the stored icon of a workflow tool provider.

      Best-effort: a default emoji icon is returned when the provider does not exist,
      when its icon cannot be decoded as JSON, or when the lookup fails in any other way.

      :param tenant_id: the id of the tenant owning the provider
      :param provider_id: the id of the workflow tool provider
      :return: the JSON-decoded icon, or the default icon
      """
      # "\U0001f601" is the emoji code point itself; the JS-style escape "\ud83d\ude01"
      # yields two lone surrogates in Python, which cannot be UTF-8 encoded.
      default_icon = {"background": "#252525", "content": "\U0001f601"}
      try:
          workflow_provider: WorkflowToolProvider | None = (
              db.session.query(WorkflowToolProvider)
              .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
              .first()
          )
          if workflow_provider is None:
              return default_icon
          icon: dict = json.loads(workflow_provider.icon)
          return icon
      except Exception:
          # Deliberately broad: a failed icon lookup falls back to the default icon.
          return default_icon
  @classmethod
  def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
      """
      Return the stored icon of an API tool provider.

      Best-effort: a default emoji icon is returned when the provider does not exist,
      when its icon cannot be decoded as JSON, or when the lookup fails in any other way.

      :param tenant_id: the id of the tenant owning the provider
      :param provider_id: the id of the API tool provider
      :return: the JSON-decoded icon, or the default icon
      """
      # "\U0001f601" is the emoji code point itself; the JS-style escape "\ud83d\ude01"
      # yields two lone surrogates in Python, which cannot be UTF-8 encoded.
      default_icon = {"background": "#252525", "content": "\U0001f601"}
      try:
          api_provider: ApiToolProvider | None = (
              db.session.query(ApiToolProvider)
              .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
              .first()
          )
          if api_provider is None:
              return default_icon
          icon: dict = json.loads(api_provider.icon)
          return icon
      except Exception:
          # Deliberately broad: a failed icon lookup falls back to the default icon.
          return default_icon
  @classmethod
  def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
      """
      Return the icon of an MCP tool provider, looked up by its server identifier.

      Best-effort: a default emoji icon is returned when the provider does not exist
      or when the lookup fails in any other way.

      :param tenant_id: the id of the tenant owning the provider
      :param provider_id: the server identifier of the MCP provider
      :return: the provider's `provider_icon`, or the default icon
      """
      # "\U0001f601" is the emoji code point itself; the JS-style escape "\ud83d\ude01"
      # yields two lone surrogates in Python, which cannot be UTF-8 encoded.
      default_icon = {"background": "#252525", "content": "\U0001f601"}
      try:
          mcp_provider: MCPToolProvider | None = (
              db.session.query(MCPToolProvider)
              .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
              .first()
          )
          if mcp_provider is None:
              return default_icon
          return mcp_provider.provider_icon
      except Exception:
          # Deliberately broad: a failed icon lookup falls back to the default icon.
          return default_icon
  @classmethod
  def get_tool_icon(
      cls,
      tenant_id: str,
      provider_type: ToolProviderType,
      provider_id: str,
  ) -> Union[str, dict]:
      """
      get the tool icon

      :param tenant_id: the id of the tenant
      :param provider_type: the type of the provider
      :param provider_id: the id of the provider
      :return: an icon URL (str) or an icon descriptor dict, depending on the provider type
      :raises ValueError: if a PLUGIN provider does not resolve to a plugin controller,
          or if the provider type is not handled
      """
      if provider_type in (ToolProviderType.BUILT_IN, ToolProviderType.PLUGIN):
          # Both types are resolved through the builtin lookup; plugin-backed controllers
          # get a plugin icon URL regardless of which of the two types was requested.
          provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
          if isinstance(provider, PluginToolProviderController):
              try:
                  return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
              except Exception:
                  # "\U0001f601" is the emoji code point itself; the JS-style escape
                  # "\ud83d\ude01" yields two lone surrogates in Python.
                  return {"background": "#252525", "content": "\U0001f601"}
          if provider_type == ToolProviderType.BUILT_IN:
              return cls.generate_builtin_tool_icon_url(provider_id)
          raise ValueError(f"plugin provider {provider_id} not found")
      elif provider_type == ToolProviderType.API:
          return cls.generate_api_tool_icon_url(tenant_id, provider_id)
      elif provider_type == ToolProviderType.WORKFLOW:
          return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
      elif provider_type == ToolProviderType.MCP:
          return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
      else:
          raise ValueError(f"provider type {provider_type} not found")
  @classmethod
  def _convert_tool_parameters_type(
      cls,
      parameters: list[ToolParameter],
      variable_pool: Optional[VariablePool],
      tool_configurations: dict[str, Any],
      typ: Literal["agent", "workflow", "tool"] = "workflow",
  ) -> dict[str, Any]:
      """
      Convert tool parameters type

      Resolve the configured values of form-type tool parameters into runtime values.
      Parameters whose form is not FORM are skipped.

      With a variable pool, each configuration entry is read as a `ToolNodeData.ToolInput`:
      a "variable" input is looked up in the pool, and "mixed"/"constant" inputs are rendered
      as templates. Entries without a non-None "value" are skipped. Without a variable pool,
      the raw configured value is converted by `parameter.init_frontend_parameter`.

      :param parameters: the tool's declared parameters
      :param variable_pool: the workflow variable pool, or None
      :param tool_configurations: configured values keyed by parameter name
      :param typ: the calling context; "agent" rejects required file-type parameters
      :return: the resolved runtime values keyed by parameter name
      :raises ValueError: for a required file-type parameter when typ == "agent"
      :raises ToolParameterError: for a missing variable or an unknown tool input type
      """
      # local imports, presumably to avoid an import cycle with the workflow package
      from core.workflow.nodes.tool.entities import ToolNodeData
      from core.workflow.nodes.tool.exc import ToolParameterError

      # loop-invariant: parameter types that carry files
      file_parameter_types = {
          ToolParameter.ToolParameterType.SYSTEM_FILES,
          ToolParameter.ToolParameterType.FILE,
          ToolParameter.ToolParameterType.FILES,
      }
      runtime_parameters: dict[str, Any] = {}
      for parameter in parameters:
          if parameter.type in file_parameter_types and parameter.required and typ == "agent":
              raise ValueError(f"file type parameter {parameter.name} not supported in agent")

          # save tool parameter to tool entity memory (form parameters only)
          if parameter.form != ToolParameter.ToolParameterForm.FORM:
              continue

          if not variable_pool:
              value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
              runtime_parameters[parameter.name] = value
              continue

          config = tool_configurations.get(parameter.name, {})
          if not (config and isinstance(config, dict) and config.get("value") is not None):
              continue
          # reuse the entry fetched above instead of looking it up a second time
          tool_input = ToolNodeData.ToolInput(**config)
          if tool_input.type == "variable":
              variable = variable_pool.get(tool_input.value)
              if variable is None:
                  raise ToolParameterError(f"Variable {tool_input.value} does not exist")
              parameter_value = variable.value
          elif tool_input.type in {"mixed", "constant"}:
              segment_group = variable_pool.convert_template(str(tool_input.value))
              parameter_value = segment_group.text
          else:
              raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
          runtime_parameters[parameter.name] = parameter_value
      return runtime_parameters
  # Module-level call: it runs once, when this module is first imported.
  # Judging by its name, it warms ToolManager's cache of hard-coded providers
  # (NOTE(review): confirm against load_hardcoded_providers_cache's definition).
  ToolManager.load_hardcoded_providers_cache()