您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

tools_transform_service.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. import json
  2. import logging
  3. from typing import Any, Optional, Union, cast
  4. from yarl import URL
  5. from configs import dify_config
  6. from core.mcp.types import Tool as MCPTool
  7. from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
  8. from core.tools.__base.tool import Tool
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  11. from core.tools.custom_tool.provider import ApiToolProviderController
  12. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  13. from core.tools.entities.common_entities import I18nObject
  14. from core.tools.entities.tool_bundle import ApiToolBundle
  15. from core.tools.entities.tool_entities import (
  16. ApiProviderAuthType,
  17. ToolParameter,
  18. ToolProviderType,
  19. )
  20. from core.tools.plugin_tool.provider import PluginToolProviderController
  21. from core.tools.utils.configuration import ProviderConfigEncrypter
  22. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  23. from core.tools.workflow_as_tool.tool import WorkflowTool
  24. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  25. logger = logging.getLogger(__name__)
  26. class ToolTransformService:
  27. @classmethod
  28. def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
  29. url_prefix = (
  30. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
  31. )
  32. return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
  33. @classmethod
  34. def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
  35. """
  36. get tool provider icon url
  37. """
  38. url_prefix = (
  39. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
  40. )
  41. if provider_type == ToolProviderType.BUILT_IN.value:
  42. return str(url_prefix / "builtin" / provider_name / "icon")
  43. elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
  44. try:
  45. if isinstance(icon, str):
  46. return cast(dict, json.loads(icon))
  47. return icon
  48. except Exception:
  49. return {"background": "#252525", "content": "\ud83d\ude01"}
  50. elif provider_type == ToolProviderType.MCP.value:
  51. return icon
  52. return ""
  53. @staticmethod
  54. def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
  55. """
  56. repack provider
  57. :param tenant_id: the tenant id
  58. :param provider: the provider dict
  59. """
  60. if isinstance(provider, dict) and "icon" in provider:
  61. provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
  62. provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
  63. )
  64. elif isinstance(provider, ToolProviderApiEntity):
  65. if provider.plugin_id:
  66. if isinstance(provider.icon, str):
  67. provider.icon = ToolTransformService.get_plugin_icon_url(
  68. tenant_id=tenant_id, filename=provider.icon
  69. )
  70. if isinstance(provider.icon_dark, str) and provider.icon_dark:
  71. provider.icon_dark = ToolTransformService.get_plugin_icon_url(
  72. tenant_id=tenant_id, filename=provider.icon_dark
  73. )
  74. else:
  75. provider.icon = ToolTransformService.get_tool_provider_icon_url(
  76. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
  77. )
  78. if provider.icon_dark:
  79. provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
  80. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
  81. )
  82. elif isinstance(provider, PluginDatasourceProviderEntity):
  83. if provider.plugin_id:
  84. if isinstance(provider.declaration.identity.icon, str):
  85. provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
  86. tenant_id=tenant_id, filename=provider.declaration.identity.icon
  87. )
  88. else:
  89. provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
  90. provider_type=provider.type.value,
  91. provider_name=provider.name,
  92. icon=provider.declaration.identity.icon,
  93. )
  94. @classmethod
  95. def builtin_provider_to_user_provider(
  96. cls,
  97. provider_controller: BuiltinToolProviderController | PluginToolProviderController,
  98. db_provider: Optional[BuiltinToolProvider],
  99. decrypt_credentials: bool = True,
  100. ) -> ToolProviderApiEntity:
  101. """
  102. convert provider controller to user provider
  103. """
  104. result = ToolProviderApiEntity(
  105. id=provider_controller.entity.identity.name,
  106. author=provider_controller.entity.identity.author,
  107. name=provider_controller.entity.identity.name,
  108. description=provider_controller.entity.identity.description,
  109. icon=provider_controller.entity.identity.icon,
  110. icon_dark=provider_controller.entity.identity.icon_dark,
  111. label=provider_controller.entity.identity.label,
  112. type=ToolProviderType.BUILT_IN,
  113. masked_credentials={},
  114. is_team_authorization=False,
  115. plugin_id=None,
  116. tools=[],
  117. labels=provider_controller.tool_labels,
  118. )
  119. if isinstance(provider_controller, PluginToolProviderController):
  120. result.plugin_id = provider_controller.plugin_id
  121. result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
  122. # get credentials schema
  123. schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
  124. for name, value in schema.items():
  125. if result.masked_credentials:
  126. result.masked_credentials[name] = ""
  127. # check if the provider need credentials
  128. if not provider_controller.need_credentials:
  129. result.is_team_authorization = True
  130. result.allow_delete = False
  131. elif db_provider:
  132. result.is_team_authorization = True
  133. if decrypt_credentials:
  134. credentials = db_provider.credentials
  135. # init tool configuration
  136. tool_configuration = ProviderConfigEncrypter(
  137. tenant_id=db_provider.tenant_id,
  138. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  139. provider_type=provider_controller.provider_type.value,
  140. provider_identity=provider_controller.entity.identity.name,
  141. )
  142. # decrypt the credentials and mask the credentials
  143. decrypted_credentials = tool_configuration.decrypt(data=credentials)
  144. masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
  145. result.masked_credentials = masked_credentials
  146. result.original_credentials = decrypted_credentials
  147. return result
  148. @staticmethod
  149. def api_provider_to_controller(
  150. db_provider: ApiToolProvider,
  151. ) -> ApiToolProviderController:
  152. """
  153. convert provider controller to user provider
  154. """
  155. # package tool provider controller
  156. auth_type = ApiProviderAuthType.NONE
  157. credentials_auth_type = db_provider.credentials.get("auth_type")
  158. if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
  159. auth_type = ApiProviderAuthType.API_KEY_HEADER
  160. elif credentials_auth_type == "api_key_query":
  161. auth_type = ApiProviderAuthType.API_KEY_QUERY
  162. controller = ApiToolProviderController.from_db(
  163. db_provider=db_provider,
  164. auth_type=auth_type,
  165. )
  166. return controller
  167. @staticmethod
  168. def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
  169. """
  170. convert provider controller to provider
  171. """
  172. return WorkflowToolProviderController.from_db(db_provider)
  173. @staticmethod
  174. def workflow_provider_to_user_provider(
  175. provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
  176. ):
  177. """
  178. convert provider controller to user provider
  179. """
  180. return ToolProviderApiEntity(
  181. id=provider_controller.provider_id,
  182. author=provider_controller.entity.identity.author,
  183. name=provider_controller.entity.identity.name,
  184. description=provider_controller.entity.identity.description,
  185. icon=provider_controller.entity.identity.icon,
  186. icon_dark=provider_controller.entity.identity.icon_dark,
  187. label=provider_controller.entity.identity.label,
  188. type=ToolProviderType.WORKFLOW,
  189. masked_credentials={},
  190. is_team_authorization=True,
  191. plugin_id=None,
  192. plugin_unique_identifier=None,
  193. tools=[],
  194. labels=labels or [],
  195. )
  196. @staticmethod
  197. def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
  198. user = db_provider.load_user()
  199. return ToolProviderApiEntity(
  200. id=db_provider.server_identifier if not for_list else db_provider.id,
  201. author=user.name if user else "Anonymous",
  202. name=db_provider.name,
  203. icon=db_provider.provider_icon,
  204. type=ToolProviderType.MCP,
  205. is_team_authorization=db_provider.authed,
  206. server_url=db_provider.masked_server_url,
  207. tools=ToolTransformService.mcp_tool_to_user_tool(
  208. db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
  209. ),
  210. updated_at=int(db_provider.updated_at.timestamp()),
  211. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  212. description=I18nObject(en_US="", zh_Hans=""),
  213. server_identifier=db_provider.server_identifier,
  214. )
  215. @staticmethod
  216. def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
  217. user = mcp_provider.load_user()
  218. return [
  219. ToolApiEntity(
  220. author=user.name if user else "Anonymous",
  221. name=tool.name,
  222. label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
  223. description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
  224. parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
  225. labels=[],
  226. )
  227. for tool in tools
  228. ]
  229. @classmethod
  230. def api_provider_to_user_provider(
  231. cls,
  232. provider_controller: ApiToolProviderController,
  233. db_provider: ApiToolProvider,
  234. decrypt_credentials: bool = True,
  235. labels: list[str] | None = None,
  236. ) -> ToolProviderApiEntity:
  237. """
  238. convert provider controller to user provider
  239. """
  240. username = "Anonymous"
  241. if db_provider.user is None:
  242. raise ValueError(f"user is None for api provider {db_provider.id}")
  243. try:
  244. user = db_provider.user
  245. if not user:
  246. raise ValueError("user not found")
  247. username = user.name
  248. except Exception:
  249. logger.exception(f"failed to get user name for api provider {db_provider.id}")
  250. # add provider into providers
  251. credentials = db_provider.credentials
  252. result = ToolProviderApiEntity(
  253. id=db_provider.id,
  254. author=username,
  255. name=db_provider.name,
  256. description=I18nObject(
  257. en_US=db_provider.description,
  258. zh_Hans=db_provider.description,
  259. ),
  260. icon=db_provider.icon,
  261. label=I18nObject(
  262. en_US=db_provider.name,
  263. zh_Hans=db_provider.name,
  264. ),
  265. type=ToolProviderType.API,
  266. plugin_id=None,
  267. plugin_unique_identifier=None,
  268. masked_credentials={},
  269. is_team_authorization=True,
  270. tools=[],
  271. labels=labels or [],
  272. )
  273. if decrypt_credentials:
  274. # init tool configuration
  275. tool_configuration = ProviderConfigEncrypter(
  276. tenant_id=db_provider.tenant_id,
  277. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  278. provider_type=provider_controller.provider_type.value,
  279. provider_identity=provider_controller.entity.identity.name,
  280. )
  281. # decrypt the credentials and mask the credentials
  282. decrypted_credentials = tool_configuration.decrypt(data=credentials)
  283. masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
  284. result.masked_credentials = masked_credentials
  285. return result
  286. @staticmethod
  287. def convert_tool_entity_to_api_entity(
  288. tool: Union[ApiToolBundle, WorkflowTool, Tool],
  289. tenant_id: str,
  290. credentials: dict | None = None,
  291. labels: list[str] | None = None,
  292. ) -> ToolApiEntity:
  293. """
  294. convert tool to user tool
  295. """
  296. if isinstance(tool, Tool):
  297. # fork tool runtime
  298. tool = tool.fork_tool_runtime(
  299. runtime=ToolRuntime(
  300. credentials=credentials or {},
  301. tenant_id=tenant_id,
  302. )
  303. )
  304. # get tool parameters
  305. parameters = tool.entity.parameters or []
  306. # get tool runtime parameters
  307. runtime_parameters = tool.get_runtime_parameters()
  308. # override parameters
  309. current_parameters = parameters.copy()
  310. for runtime_parameter in runtime_parameters:
  311. found = False
  312. for index, parameter in enumerate(current_parameters):
  313. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  314. current_parameters[index] = runtime_parameter
  315. found = True
  316. break
  317. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  318. current_parameters.append(runtime_parameter)
  319. return ToolApiEntity(
  320. author=tool.entity.identity.author,
  321. name=tool.entity.identity.name,
  322. label=tool.entity.identity.label,
  323. description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
  324. output_schema=tool.entity.output_schema,
  325. parameters=current_parameters,
  326. labels=labels or [],
  327. )
  328. if isinstance(tool, ApiToolBundle):
  329. return ToolApiEntity(
  330. author=tool.author,
  331. name=tool.operation_id or "",
  332. label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
  333. description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
  334. parameters=tool.parameters,
  335. labels=labels or [],
  336. )
  337. @staticmethod
  338. def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
  339. """
  340. Convert MCP JSON schema to tool parameters
  341. :param schema: JSON schema dictionary
  342. :return: list of ToolParameter instances
  343. """
  344. def create_parameter(
  345. name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
  346. ) -> ToolParameter:
  347. """Create a ToolParameter instance with given attributes"""
  348. input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
  349. return ToolParameter(
  350. name=name,
  351. llm_description=description,
  352. label=I18nObject(en_US=name),
  353. form=ToolParameter.ToolParameterForm.LLM,
  354. required=required,
  355. type=ToolParameter.ToolParameterType(param_type),
  356. human_description=I18nObject(en_US=description),
  357. **input_schema_dict,
  358. )
  359. def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
  360. """Process properties recursively"""
  361. TYPE_MAPPING = {"integer": "number", "float": "number"}
  362. COMPLEX_TYPES = ["array", "object"]
  363. parameters = []
  364. for name, prop in props.items():
  365. current_description = prop.get("description", "")
  366. prop_type = prop.get("type", "string")
  367. if isinstance(prop_type, list):
  368. prop_type = prop_type[0]
  369. if prop_type in TYPE_MAPPING:
  370. prop_type = TYPE_MAPPING[prop_type]
  371. input_schema = prop if prop_type in COMPLEX_TYPES else None
  372. parameters.append(
  373. create_parameter(name, current_description, prop_type, name in required, input_schema)
  374. )
  375. return parameters
  376. if schema.get("type") == "object" and "properties" in schema:
  377. return process_properties(schema["properties"], schema.get("required", []))
  378. return []