You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tools_transform_service.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. import json
  2. import logging
  3. from collections.abc import Mapping
  4. from typing import Any, Union
  5. from yarl import URL
  6. from configs import dify_config
  7. from core.helper.provider_cache import ToolProviderCredentialsCache
  8. from core.mcp.types import Tool as MCPTool
  9. from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
  10. from core.tools.__base.tool import Tool
  11. from core.tools.__base.tool_runtime import ToolRuntime
  12. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  13. from core.tools.custom_tool.provider import ApiToolProviderController
  14. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
  15. from core.tools.entities.common_entities import I18nObject
  16. from core.tools.entities.tool_bundle import ApiToolBundle
  17. from core.tools.entities.tool_entities import (
  18. ApiProviderAuthType,
  19. CredentialType,
  20. ToolParameter,
  21. ToolProviderType,
  22. )
  23. from core.tools.plugin_tool.provider import PluginToolProviderController
  24. from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
  25. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  26. from core.tools.workflow_as_tool.tool import WorkflowTool
  27. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  28. logger = logging.getLogger(__name__)
  29. class ToolTransformService:
  30. @classmethod
  31. def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
  32. url_prefix = (
  33. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
  34. )
  35. return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
  36. @classmethod
  37. def get_tool_provider_icon_url(
  38. cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
  39. ) -> str | Mapping[str, str]:
  40. """
  41. get tool provider icon url
  42. """
  43. url_prefix = (
  44. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
  45. )
  46. if provider_type == ToolProviderType.BUILT_IN.value:
  47. return str(url_prefix / "builtin" / provider_name / "icon")
  48. elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
  49. try:
  50. if isinstance(icon, str):
  51. return json.loads(icon)
  52. return icon
  53. except Exception:
  54. return {"background": "#252525", "content": "\ud83d\ude01"}
  55. elif provider_type == ToolProviderType.MCP.value:
  56. return icon
  57. return ""
  58. @staticmethod
  59. def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
  60. """
  61. repack provider
  62. :param tenant_id: the tenant id
  63. :param provider: the provider dict
  64. """
  65. if isinstance(provider, dict) and "icon" in provider:
  66. provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
  67. provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
  68. )
  69. elif isinstance(provider, ToolProviderApiEntity):
  70. if provider.plugin_id:
  71. if isinstance(provider.icon, str):
  72. provider.icon = ToolTransformService.get_plugin_icon_url(
  73. tenant_id=tenant_id, filename=provider.icon
  74. )
  75. if isinstance(provider.icon_dark, str) and provider.icon_dark:
  76. provider.icon_dark = ToolTransformService.get_plugin_icon_url(
  77. tenant_id=tenant_id, filename=provider.icon_dark
  78. )
  79. else:
  80. provider.icon = ToolTransformService.get_tool_provider_icon_url(
  81. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
  82. )
  83. if provider.icon_dark:
  84. provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
  85. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
  86. )
  87. elif isinstance(provider, PluginDatasourceProviderEntity):
  88. if provider.plugin_id:
  89. if isinstance(provider.declaration.identity.icon, str):
  90. provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
  91. tenant_id=tenant_id, filename=provider.declaration.identity.icon
  92. )
  93. @classmethod
  94. def builtin_provider_to_user_provider(
  95. cls,
  96. provider_controller: BuiltinToolProviderController | PluginToolProviderController,
  97. db_provider: BuiltinToolProvider | None,
  98. decrypt_credentials: bool = True,
  99. ) -> ToolProviderApiEntity:
  100. """
  101. convert provider controller to user provider
  102. """
  103. result = ToolProviderApiEntity(
  104. id=provider_controller.entity.identity.name,
  105. author=provider_controller.entity.identity.author,
  106. name=provider_controller.entity.identity.name,
  107. description=provider_controller.entity.identity.description,
  108. icon=provider_controller.entity.identity.icon,
  109. icon_dark=provider_controller.entity.identity.icon_dark or "",
  110. label=provider_controller.entity.identity.label,
  111. type=ToolProviderType.BUILT_IN,
  112. masked_credentials={},
  113. is_team_authorization=False,
  114. plugin_id=None,
  115. tools=[],
  116. labels=provider_controller.tool_labels,
  117. )
  118. if isinstance(provider_controller, PluginToolProviderController):
  119. result.plugin_id = provider_controller.plugin_id
  120. result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
  121. # get credentials schema
  122. schema = {
  123. x.to_basic_provider_config().name: x
  124. for x in provider_controller.get_credentials_schema_by_type(
  125. CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
  126. )
  127. }
  128. masked_creds = {}
  129. for name in schema:
  130. masked_creds[name] = ""
  131. result.masked_credentials = masked_creds
  132. # check if the provider need credentials
  133. if not provider_controller.need_credentials:
  134. result.is_team_authorization = True
  135. result.allow_delete = False
  136. elif db_provider:
  137. result.is_team_authorization = True
  138. if decrypt_credentials:
  139. credentials = db_provider.credentials
  140. # init tool configuration
  141. encrypter, _ = create_provider_encrypter(
  142. tenant_id=db_provider.tenant_id,
  143. config=[
  144. x.to_basic_provider_config()
  145. for x in provider_controller.get_credentials_schema_by_type(
  146. CredentialType.of(db_provider.credential_type)
  147. )
  148. ],
  149. cache=ToolProviderCredentialsCache(
  150. tenant_id=db_provider.tenant_id,
  151. provider=db_provider.provider,
  152. credential_id=db_provider.id,
  153. ),
  154. )
  155. # decrypt the credentials and mask the credentials
  156. decrypted_credentials = encrypter.decrypt(data=credentials)
  157. masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
  158. result.masked_credentials = masked_credentials
  159. result.original_credentials = decrypted_credentials
  160. return result
  161. @staticmethod
  162. def api_provider_to_controller(
  163. db_provider: ApiToolProvider,
  164. ) -> ApiToolProviderController:
  165. """
  166. convert provider controller to user provider
  167. """
  168. # package tool provider controller
  169. auth_type = ApiProviderAuthType.NONE
  170. credentials_auth_type = db_provider.credentials.get("auth_type")
  171. if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
  172. auth_type = ApiProviderAuthType.API_KEY_HEADER
  173. elif credentials_auth_type == "api_key_query":
  174. auth_type = ApiProviderAuthType.API_KEY_QUERY
  175. controller = ApiToolProviderController.from_db(
  176. db_provider=db_provider,
  177. auth_type=auth_type,
  178. )
  179. return controller
  180. @staticmethod
  181. def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
  182. """
  183. convert provider controller to provider
  184. """
  185. return WorkflowToolProviderController.from_db(db_provider)
  186. @staticmethod
  187. def workflow_provider_to_user_provider(
  188. provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
  189. ):
  190. """
  191. convert provider controller to user provider
  192. """
  193. return ToolProviderApiEntity(
  194. id=provider_controller.provider_id,
  195. author=provider_controller.entity.identity.author,
  196. name=provider_controller.entity.identity.name,
  197. description=provider_controller.entity.identity.description,
  198. icon=provider_controller.entity.identity.icon,
  199. icon_dark=provider_controller.entity.identity.icon_dark or "",
  200. label=provider_controller.entity.identity.label,
  201. type=ToolProviderType.WORKFLOW,
  202. masked_credentials={},
  203. is_team_authorization=True,
  204. plugin_id=None,
  205. plugin_unique_identifier=None,
  206. tools=[],
  207. labels=labels or [],
  208. )
  209. @staticmethod
  210. def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
  211. user = db_provider.load_user()
  212. return ToolProviderApiEntity(
  213. id=db_provider.server_identifier if not for_list else db_provider.id,
  214. author=user.name if user else "Anonymous",
  215. name=db_provider.name,
  216. icon=db_provider.provider_icon,
  217. type=ToolProviderType.MCP,
  218. is_team_authorization=db_provider.authed,
  219. server_url=db_provider.masked_server_url,
  220. tools=ToolTransformService.mcp_tool_to_user_tool(
  221. db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
  222. ),
  223. updated_at=int(db_provider.updated_at.timestamp()),
  224. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  225. description=I18nObject(en_US="", zh_Hans=""),
  226. server_identifier=db_provider.server_identifier,
  227. timeout=db_provider.timeout,
  228. sse_read_timeout=db_provider.sse_read_timeout,
  229. masked_headers=db_provider.masked_headers,
  230. original_headers=db_provider.decrypted_headers,
  231. )
  232. @staticmethod
  233. def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
  234. user = mcp_provider.load_user()
  235. return [
  236. ToolApiEntity(
  237. author=user.name if user else "Anonymous",
  238. name=tool.name,
  239. label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
  240. description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
  241. parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
  242. labels=[],
  243. )
  244. for tool in tools
  245. ]
  246. @classmethod
  247. def api_provider_to_user_provider(
  248. cls,
  249. provider_controller: ApiToolProviderController,
  250. db_provider: ApiToolProvider,
  251. decrypt_credentials: bool = True,
  252. labels: list[str] | None = None,
  253. ) -> ToolProviderApiEntity:
  254. """
  255. convert provider controller to user provider
  256. """
  257. username = "Anonymous"
  258. if db_provider.user is None:
  259. raise ValueError(f"user is None for api provider {db_provider.id}")
  260. try:
  261. user = db_provider.user
  262. if not user:
  263. raise ValueError("user not found")
  264. username = user.name
  265. except Exception:
  266. logger.exception("failed to get user name for api provider %s", db_provider.id)
  267. # add provider into providers
  268. credentials = db_provider.credentials
  269. result = ToolProviderApiEntity(
  270. id=db_provider.id,
  271. author=username,
  272. name=db_provider.name,
  273. description=I18nObject(
  274. en_US=db_provider.description,
  275. zh_Hans=db_provider.description,
  276. ),
  277. icon=db_provider.icon,
  278. label=I18nObject(
  279. en_US=db_provider.name,
  280. zh_Hans=db_provider.name,
  281. ),
  282. type=ToolProviderType.API,
  283. plugin_id=None,
  284. plugin_unique_identifier=None,
  285. masked_credentials={},
  286. is_team_authorization=True,
  287. tools=[],
  288. labels=labels or [],
  289. )
  290. if decrypt_credentials:
  291. # init tool configuration
  292. encrypter, _ = create_tool_provider_encrypter(
  293. tenant_id=db_provider.tenant_id,
  294. controller=provider_controller,
  295. )
  296. # decrypt the credentials and mask the credentials
  297. decrypted_credentials = encrypter.decrypt(data=credentials)
  298. masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
  299. result.masked_credentials = masked_credentials
  300. return result
  301. @staticmethod
  302. def convert_tool_entity_to_api_entity(
  303. tool: ApiToolBundle | WorkflowTool | Tool,
  304. tenant_id: str,
  305. labels: list[str] | None = None,
  306. ) -> ToolApiEntity:
  307. """
  308. convert tool to user tool
  309. """
  310. if isinstance(tool, Tool):
  311. # fork tool runtime
  312. tool = tool.fork_tool_runtime(
  313. runtime=ToolRuntime(
  314. credentials={},
  315. tenant_id=tenant_id,
  316. )
  317. )
  318. # get tool parameters
  319. base_parameters = tool.entity.parameters or []
  320. # get tool runtime parameters
  321. runtime_parameters = tool.get_runtime_parameters()
  322. # merge parameters using a functional approach to avoid type issues
  323. merged_parameters: list[ToolParameter] = []
  324. # create a mapping of runtime parameters for quick lookup
  325. runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
  326. # process base parameters, replacing with runtime versions if they exist
  327. for base_param in base_parameters:
  328. key = (base_param.name, base_param.form)
  329. if key in runtime_param_map:
  330. merged_parameters.append(runtime_param_map[key])
  331. else:
  332. merged_parameters.append(base_param)
  333. # add any runtime parameters that weren't in base parameters
  334. for runtime_parameter in runtime_parameters:
  335. if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  336. # check if this parameter is already in merged_parameters
  337. already_exists = any(
  338. p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
  339. )
  340. if not already_exists:
  341. merged_parameters.append(runtime_parameter)
  342. return ToolApiEntity(
  343. author=tool.entity.identity.author,
  344. name=tool.entity.identity.name,
  345. label=tool.entity.identity.label,
  346. description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
  347. output_schema=tool.entity.output_schema,
  348. parameters=merged_parameters,
  349. labels=labels or [],
  350. )
  351. else:
  352. return ToolApiEntity(
  353. author=tool.author,
  354. name=tool.operation_id or "",
  355. label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
  356. description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
  357. parameters=tool.parameters,
  358. labels=labels or [],
  359. )
  360. @staticmethod
  361. def convert_builtin_provider_to_credential_entity(
  362. provider: BuiltinToolProvider, credentials: dict
  363. ) -> ToolProviderCredentialApiEntity:
  364. return ToolProviderCredentialApiEntity(
  365. id=provider.id,
  366. name=provider.name,
  367. provider=provider.provider,
  368. credential_type=CredentialType.of(provider.credential_type),
  369. is_default=provider.is_default,
  370. credentials=credentials,
  371. )
  372. @staticmethod
  373. def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
  374. """
  375. Convert MCP JSON schema to tool parameters
  376. :param schema: JSON schema dictionary
  377. :return: list of ToolParameter instances
  378. """
  379. def create_parameter(
  380. name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
  381. ) -> ToolParameter:
  382. """Create a ToolParameter instance with given attributes"""
  383. input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
  384. return ToolParameter(
  385. name=name,
  386. llm_description=description,
  387. label=I18nObject(en_US=name),
  388. form=ToolParameter.ToolParameterForm.LLM,
  389. required=required,
  390. type=ToolParameter.ToolParameterType(param_type),
  391. human_description=I18nObject(en_US=description),
  392. **input_schema_dict,
  393. )
  394. def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
  395. """Process properties recursively"""
  396. TYPE_MAPPING = {"integer": "number", "float": "number"}
  397. COMPLEX_TYPES = ["array", "object"]
  398. parameters = []
  399. for name, prop in props.items():
  400. current_description = prop.get("description", "")
  401. prop_type = prop.get("type", "string")
  402. if isinstance(prop_type, list):
  403. prop_type = prop_type[0]
  404. if prop_type in TYPE_MAPPING:
  405. prop_type = TYPE_MAPPING[prop_type]
  406. input_schema = prop if prop_type in COMPLEX_TYPES else None
  407. parameters.append(
  408. create_parameter(name, current_description, prop_type, name in required, input_schema)
  409. )
  410. return parameters
  411. if schema.get("type") == "object" and "properties" in schema:
  412. return process_properties(schema["properties"], schema.get("required", []))
  413. return []