You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

tools_transform_service.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. import json
  2. import logging
  3. from typing import Any, Optional, Union, cast
  4. from yarl import URL
  5. from configs import dify_config
  6. from core.helper.provider_cache import ToolProviderCredentialsCache
  7. from core.mcp.types import Tool as MCPTool
  8. from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
  9. from core.tools.__base.tool import Tool
  10. from core.tools.__base.tool_runtime import ToolRuntime
  11. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  12. from core.tools.custom_tool.provider import ApiToolProviderController
  13. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
  14. from core.tools.entities.common_entities import I18nObject
  15. from core.tools.entities.tool_bundle import ApiToolBundle
  16. from core.tools.entities.tool_entities import (
  17. ApiProviderAuthType,
  18. CredentialType,
  19. ToolParameter,
  20. ToolProviderType,
  21. )
  22. from core.tools.plugin_tool.provider import PluginToolProviderController
  23. from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
  24. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  25. from core.tools.workflow_as_tool.tool import WorkflowTool
  26. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  27. logger = logging.getLogger(__name__)
  28. class ToolTransformService:
  29. @classmethod
  30. def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
  31. url_prefix = (
  32. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
  33. )
  34. return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
  35. @classmethod
  36. def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
  37. """
  38. get tool provider icon url
  39. """
  40. url_prefix = (
  41. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
  42. )
  43. if provider_type == ToolProviderType.BUILT_IN.value:
  44. return str(url_prefix / "builtin" / provider_name / "icon")
  45. elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
  46. try:
  47. if isinstance(icon, str):
  48. return cast(dict, json.loads(icon))
  49. return icon
  50. except Exception:
  51. return {"background": "#252525", "content": "\ud83d\ude01"}
  52. elif provider_type == ToolProviderType.MCP.value:
  53. return icon
  54. return ""
  55. @staticmethod
  56. def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
  57. """
  58. repack provider
  59. :param tenant_id: the tenant id
  60. :param provider: the provider dict
  61. """
  62. if isinstance(provider, dict) and "icon" in provider:
  63. provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
  64. provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
  65. )
  66. elif isinstance(provider, ToolProviderApiEntity):
  67. if provider.plugin_id:
  68. if isinstance(provider.icon, str):
  69. provider.icon = ToolTransformService.get_plugin_icon_url(
  70. tenant_id=tenant_id, filename=provider.icon
  71. )
  72. if isinstance(provider.icon_dark, str) and provider.icon_dark:
  73. provider.icon_dark = ToolTransformService.get_plugin_icon_url(
  74. tenant_id=tenant_id, filename=provider.icon_dark
  75. )
  76. else:
  77. provider.icon = ToolTransformService.get_tool_provider_icon_url(
  78. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
  79. )
  80. if provider.icon_dark:
  81. provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
  82. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
  83. )
  84. elif isinstance(provider, PluginDatasourceProviderEntity):
  85. if provider.plugin_id:
  86. if isinstance(provider.declaration.identity.icon, str):
  87. provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
  88. tenant_id=tenant_id, filename=provider.declaration.identity.icon
  89. )
  90. else:
  91. provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
  92. provider_type=provider.type.value,
  93. provider_name=provider.name,
  94. icon=provider.declaration.identity.icon,
  95. )
  96. @classmethod
  97. def builtin_provider_to_user_provider(
  98. cls,
  99. provider_controller: BuiltinToolProviderController | PluginToolProviderController,
  100. db_provider: Optional[BuiltinToolProvider],
  101. decrypt_credentials: bool = True,
  102. ) -> ToolProviderApiEntity:
  103. """
  104. convert provider controller to user provider
  105. """
  106. result = ToolProviderApiEntity(
  107. id=provider_controller.entity.identity.name,
  108. author=provider_controller.entity.identity.author,
  109. name=provider_controller.entity.identity.name,
  110. description=provider_controller.entity.identity.description,
  111. icon=provider_controller.entity.identity.icon,
  112. icon_dark=provider_controller.entity.identity.icon_dark,
  113. label=provider_controller.entity.identity.label,
  114. type=ToolProviderType.BUILT_IN,
  115. masked_credentials={},
  116. is_team_authorization=False,
  117. plugin_id=None,
  118. tools=[],
  119. labels=provider_controller.tool_labels,
  120. )
  121. if isinstance(provider_controller, PluginToolProviderController):
  122. result.plugin_id = provider_controller.plugin_id
  123. result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
  124. # get credentials schema
  125. schema = {
  126. x.to_basic_provider_config().name: x
  127. for x in provider_controller.get_credentials_schema_by_type(
  128. CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
  129. )
  130. }
  131. for name, value in schema.items():
  132. if result.masked_credentials:
  133. result.masked_credentials[name] = ""
  134. # check if the provider need credentials
  135. if not provider_controller.need_credentials:
  136. result.is_team_authorization = True
  137. result.allow_delete = False
  138. elif db_provider:
  139. result.is_team_authorization = True
  140. if decrypt_credentials:
  141. credentials = db_provider.credentials
  142. # init tool configuration
  143. encrypter, _ = create_provider_encrypter(
  144. tenant_id=db_provider.tenant_id,
  145. config=[
  146. x.to_basic_provider_config()
  147. for x in provider_controller.get_credentials_schema_by_type(
  148. CredentialType.of(db_provider.credential_type)
  149. )
  150. ],
  151. cache=ToolProviderCredentialsCache(
  152. tenant_id=db_provider.tenant_id,
  153. provider=db_provider.provider,
  154. credential_id=db_provider.id,
  155. ),
  156. )
  157. # decrypt the credentials and mask the credentials
  158. decrypted_credentials = encrypter.decrypt(data=credentials)
  159. masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
  160. result.masked_credentials = masked_credentials
  161. result.original_credentials = decrypted_credentials
  162. return result
  163. @staticmethod
  164. def api_provider_to_controller(
  165. db_provider: ApiToolProvider,
  166. ) -> ApiToolProviderController:
  167. """
  168. convert provider controller to user provider
  169. """
  170. # package tool provider controller
  171. auth_type = ApiProviderAuthType.NONE
  172. credentials_auth_type = db_provider.credentials.get("auth_type")
  173. if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
  174. auth_type = ApiProviderAuthType.API_KEY_HEADER
  175. elif credentials_auth_type == "api_key_query":
  176. auth_type = ApiProviderAuthType.API_KEY_QUERY
  177. controller = ApiToolProviderController.from_db(
  178. db_provider=db_provider,
  179. auth_type=auth_type,
  180. )
  181. return controller
  182. @staticmethod
  183. def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
  184. """
  185. convert provider controller to provider
  186. """
  187. return WorkflowToolProviderController.from_db(db_provider)
  188. @staticmethod
  189. def workflow_provider_to_user_provider(
  190. provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
  191. ):
  192. """
  193. convert provider controller to user provider
  194. """
  195. return ToolProviderApiEntity(
  196. id=provider_controller.provider_id,
  197. author=provider_controller.entity.identity.author,
  198. name=provider_controller.entity.identity.name,
  199. description=provider_controller.entity.identity.description,
  200. icon=provider_controller.entity.identity.icon,
  201. icon_dark=provider_controller.entity.identity.icon_dark,
  202. label=provider_controller.entity.identity.label,
  203. type=ToolProviderType.WORKFLOW,
  204. masked_credentials={},
  205. is_team_authorization=True,
  206. plugin_id=None,
  207. plugin_unique_identifier=None,
  208. tools=[],
  209. labels=labels or [],
  210. )
  211. @staticmethod
  212. def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
  213. user = db_provider.load_user()
  214. return ToolProviderApiEntity(
  215. id=db_provider.server_identifier if not for_list else db_provider.id,
  216. author=user.name if user else "Anonymous",
  217. name=db_provider.name,
  218. icon=db_provider.provider_icon,
  219. type=ToolProviderType.MCP,
  220. is_team_authorization=db_provider.authed,
  221. server_url=db_provider.masked_server_url,
  222. tools=ToolTransformService.mcp_tool_to_user_tool(
  223. db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
  224. ),
  225. updated_at=int(db_provider.updated_at.timestamp()),
  226. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  227. description=I18nObject(en_US="", zh_Hans=""),
  228. server_identifier=db_provider.server_identifier,
  229. )
  230. @staticmethod
  231. def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
  232. user = mcp_provider.load_user()
  233. return [
  234. ToolApiEntity(
  235. author=user.name if user else "Anonymous",
  236. name=tool.name,
  237. label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
  238. description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
  239. parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
  240. labels=[],
  241. )
  242. for tool in tools
  243. ]
  244. @classmethod
  245. def api_provider_to_user_provider(
  246. cls,
  247. provider_controller: ApiToolProviderController,
  248. db_provider: ApiToolProvider,
  249. decrypt_credentials: bool = True,
  250. labels: list[str] | None = None,
  251. ) -> ToolProviderApiEntity:
  252. """
  253. convert provider controller to user provider
  254. """
  255. username = "Anonymous"
  256. if db_provider.user is None:
  257. raise ValueError(f"user is None for api provider {db_provider.id}")
  258. try:
  259. user = db_provider.user
  260. if not user:
  261. raise ValueError("user not found")
  262. username = user.name
  263. except Exception:
  264. logger.exception("failed to get user name for api provider %s", db_provider.id)
  265. # add provider into providers
  266. credentials = db_provider.credentials
  267. result = ToolProviderApiEntity(
  268. id=db_provider.id,
  269. author=username,
  270. name=db_provider.name,
  271. description=I18nObject(
  272. en_US=db_provider.description,
  273. zh_Hans=db_provider.description,
  274. ),
  275. icon=db_provider.icon,
  276. label=I18nObject(
  277. en_US=db_provider.name,
  278. zh_Hans=db_provider.name,
  279. ),
  280. type=ToolProviderType.API,
  281. plugin_id=None,
  282. plugin_unique_identifier=None,
  283. masked_credentials={},
  284. is_team_authorization=True,
  285. tools=[],
  286. labels=labels or [],
  287. )
  288. if decrypt_credentials:
  289. # init tool configuration
  290. encrypter, _ = create_tool_provider_encrypter(
  291. tenant_id=db_provider.tenant_id,
  292. controller=provider_controller,
  293. )
  294. # decrypt the credentials and mask the credentials
  295. decrypted_credentials = encrypter.decrypt(data=credentials)
  296. masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
  297. result.masked_credentials = masked_credentials
  298. return result
  299. @staticmethod
  300. def convert_tool_entity_to_api_entity(
  301. tool: Union[ApiToolBundle, WorkflowTool, Tool],
  302. tenant_id: str,
  303. labels: list[str] | None = None,
  304. ) -> ToolApiEntity:
  305. """
  306. convert tool to user tool
  307. """
  308. if isinstance(tool, Tool):
  309. # fork tool runtime
  310. tool = tool.fork_tool_runtime(
  311. runtime=ToolRuntime(
  312. credentials={},
  313. tenant_id=tenant_id,
  314. )
  315. )
  316. # get tool parameters
  317. base_parameters = tool.entity.parameters or []
  318. # get tool runtime parameters
  319. runtime_parameters = tool.get_runtime_parameters()
  320. # merge parameters using a functional approach to avoid type issues
  321. merged_parameters: list[ToolParameter] = []
  322. # create a mapping of runtime parameters for quick lookup
  323. runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
  324. # process base parameters, replacing with runtime versions if they exist
  325. for base_param in base_parameters:
  326. key = (base_param.name, base_param.form)
  327. if key in runtime_param_map:
  328. merged_parameters.append(runtime_param_map[key])
  329. else:
  330. merged_parameters.append(base_param)
  331. # add any runtime parameters that weren't in base parameters
  332. for runtime_parameter in runtime_parameters:
  333. if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  334. # check if this parameter is already in merged_parameters
  335. already_exists = any(
  336. p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
  337. )
  338. if not already_exists:
  339. merged_parameters.append(runtime_parameter)
  340. return ToolApiEntity(
  341. author=tool.entity.identity.author,
  342. name=tool.entity.identity.name,
  343. label=tool.entity.identity.label,
  344. description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
  345. output_schema=tool.entity.output_schema,
  346. parameters=merged_parameters,
  347. labels=labels or [],
  348. )
  349. elif isinstance(tool, ApiToolBundle):
  350. return ToolApiEntity(
  351. author=tool.author,
  352. name=tool.operation_id or "",
  353. label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
  354. description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
  355. parameters=tool.parameters,
  356. labels=labels or [],
  357. )
  358. else:
  359. # Handle WorkflowTool case
  360. raise ValueError(f"Unsupported tool type: {type(tool)}")
  361. @staticmethod
  362. def convert_builtin_provider_to_credential_entity(
  363. provider: BuiltinToolProvider, credentials: dict
  364. ) -> ToolProviderCredentialApiEntity:
  365. return ToolProviderCredentialApiEntity(
  366. id=provider.id,
  367. name=provider.name,
  368. provider=provider.provider,
  369. credential_type=CredentialType.of(provider.credential_type),
  370. is_default=provider.is_default,
  371. credentials=credentials,
  372. )
  373. @staticmethod
  374. def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
  375. """
  376. Convert MCP JSON schema to tool parameters
  377. :param schema: JSON schema dictionary
  378. :return: list of ToolParameter instances
  379. """
  380. def create_parameter(
  381. name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
  382. ) -> ToolParameter:
  383. """Create a ToolParameter instance with given attributes"""
  384. input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
  385. return ToolParameter(
  386. name=name,
  387. llm_description=description,
  388. label=I18nObject(en_US=name),
  389. form=ToolParameter.ToolParameterForm.LLM,
  390. required=required,
  391. type=ToolParameter.ToolParameterType(param_type),
  392. human_description=I18nObject(en_US=description),
  393. **input_schema_dict,
  394. )
  395. def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
  396. """Process properties recursively"""
  397. TYPE_MAPPING = {"integer": "number", "float": "number"}
  398. COMPLEX_TYPES = ["array", "object"]
  399. parameters = []
  400. for name, prop in props.items():
  401. current_description = prop.get("description", "")
  402. prop_type = prop.get("type", "string")
  403. if isinstance(prop_type, list):
  404. prop_type = prop_type[0]
  405. if prop_type in TYPE_MAPPING:
  406. prop_type = TYPE_MAPPING[prop_type]
  407. input_schema = prop if prop_type in COMPLEX_TYPES else None
  408. parameters.append(
  409. create_parameter(name, current_description, prop_type, name in required, input_schema)
  410. )
  411. return parameters
  412. if schema.get("type") == "object" and "properties" in schema:
  413. return process_properties(schema["properties"], schema.get("required", []))
  414. return []