Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

tools_transform_service.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. import json
  2. import logging
  3. from typing import Any, Optional, Union, cast
  4. from yarl import URL
  5. from configs import dify_config
  6. from core.mcp.types import Tool as MCPTool
  7. from core.tools.__base.tool import Tool
  8. from core.tools.__base.tool_runtime import ToolRuntime
  9. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  10. from core.tools.custom_tool.provider import ApiToolProviderController
  11. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  12. from core.tools.entities.common_entities import I18nObject
  13. from core.tools.entities.tool_bundle import ApiToolBundle
  14. from core.tools.entities.tool_entities import (
  15. ApiProviderAuthType,
  16. ToolParameter,
  17. ToolProviderType,
  18. )
  19. from core.tools.plugin_tool.provider import PluginToolProviderController
  20. from core.tools.utils.configuration import ProviderConfigEncrypter
  21. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  22. from core.tools.workflow_as_tool.tool import WorkflowTool
  23. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  24. logger = logging.getLogger(__name__)
  25. class ToolTransformService:
  @classmethod
  def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
      """
      get plugin icon url

      The path ``console/api/workspaces/current/plugin/icon`` is appended to
      ``dify_config.CONSOLE_API_URL`` (or ``"/"`` when that setting is empty).

      :param tenant_id: value sent as the ``tenant_id`` query entry
      :param filename: value sent as the ``filename`` query entry
      :return: the resulting url as a string
      """
      icon_endpoint = URL(dify_config.CONSOLE_API_URL or "/")
      for segment in ("console", "api", "workspaces", "current", "plugin", "icon"):
          icon_endpoint = icon_endpoint / segment

      query = {"tenant_id": tenant_id, "filename": filename}
      # `%` applies the mapping to the URL's query string
      return str(icon_endpoint % query)
  @classmethod
  def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
      """
      get tool provider icon url

      :param provider_type: provider type value, compared against ``ToolProviderType`` values
      :param provider_name: provider name, only used to build the builtin icon url
      :param icon: raw icon of the provider (JSON string or dict)
      :return:
          - builtin provider: url string of the provider icon endpoint
          - api / workflow provider: ``icon`` itself when it is not a string, otherwise the
            parsed JSON; a default emoji icon when parsing fails
          - mcp provider: ``icon`` unchanged
          - any other type: empty string
      """
      if provider_type == ToolProviderType.BUILT_IN.value:
          # only the builtin branch needs the console url, so build it here
          url_prefix = (
              URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
          )
          return str(url_prefix / "builtin" / provider_name / "icon")

      if provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
          if not isinstance(icon, str):
              return icon
          try:
              return cast(dict, json.loads(icon))
          except Exception:
              # best effort: an unparsable icon falls back to the default emoji icon.
              # The emoji is written as one code point; the previous "\ud83d\ude01"
              # literal was a pair of lone surrogates, which cannot be encoded to UTF-8.
              return {"background": "#252525", "content": "\U0001f601"}

      if provider_type == ToolProviderType.MCP.value:
          return icon

      return ""
  @staticmethod
  def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
      """
      rewrite the icon fields of a provider in place

      :param tenant_id: the tenant id, used for plugin icon urls
      :param provider: provider dict (needs an ``icon`` key to be touched) or
          ``ToolProviderApiEntity``; anything else is left untouched
      """
      resolve_icon = ToolTransformService.get_tool_provider_icon_url

      # dict form: only the "icon" entry is rewritten
      if isinstance(provider, dict) and "icon" in provider:
          provider["icon"] = resolve_icon(
              provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
          )
          return

      if not isinstance(provider, ToolProviderApiEntity):
          return

      if provider.plugin_id:
          # plugin providers: string icons are file names served by the plugin icon endpoint
          if isinstance(provider.icon, str):
              provider.icon = ToolTransformService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
          dark_icon = provider.icon_dark
          if dark_icon and isinstance(dark_icon, str):
              provider.icon_dark = ToolTransformService.get_plugin_icon_url(tenant_id=tenant_id, filename=dark_icon)
          return

      # non-plugin providers: resolve by provider type
      provider.icon = resolve_icon(
          provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
      )
      if provider.icon_dark:
          provider.icon_dark = resolve_icon(
              provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
          )
  @classmethod
  def builtin_provider_to_user_provider(
      cls,
      provider_controller: BuiltinToolProviderController | PluginToolProviderController,
      db_provider: Optional[BuiltinToolProvider],
      decrypt_credentials: bool = True,
  ) -> ToolProviderApiEntity:
      """
      convert provider controller to user provider

      :param provider_controller: builtin or plugin provider controller
      :param db_provider: stored provider record, or None when nothing is stored
      :param decrypt_credentials: when true and ``db_provider`` is given, decrypt the stored
          credentials and expose them as ``masked_credentials`` / ``original_credentials``
      :return: the provider api entity (``tools`` is left empty)
      """
      identity = provider_controller.entity.identity
      result = ToolProviderApiEntity(
          id=identity.name,
          author=identity.author,
          name=identity.name,
          description=identity.description,
          icon=identity.icon,
          icon_dark=identity.icon_dark,
          label=identity.label,
          type=ToolProviderType.BUILT_IN,
          masked_credentials={},
          is_team_authorization=False,
          plugin_id=None,
          tools=[],
          labels=provider_controller.tool_labels,
      )

      if isinstance(provider_controller, PluginToolProviderController):
          result.plugin_id = provider_controller.plugin_id
          result.plugin_unique_identifier = provider_controller.plugin_unique_identifier

      # get credentials schema (converted once, reused for the encrypter below)
      credential_configs = [x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()]

      # One empty placeholder per credential field. The previous code guarded each
      # assignment with `if result.masked_credentials:`, which is always false for the
      # dict that was just initialised to {}, so no placeholder was ever written.
      result.masked_credentials = {config.name: "" for config in credential_configs}

      # check if the provider need credentials
      if not provider_controller.need_credentials:
          result.is_team_authorization = True
          result.allow_delete = False
      elif db_provider:
          result.is_team_authorization = True

          if decrypt_credentials:
              credentials = db_provider.credentials

              # init tool configuration
              tool_configuration = ProviderConfigEncrypter(
                  tenant_id=db_provider.tenant_id,
                  config=credential_configs,
                  provider_type=provider_controller.provider_type.value,
                  provider_identity=provider_controller.entity.identity.name,
              )
              # decrypt the credentials and mask the credentials
              decrypted_credentials = tool_configuration.decrypt(data=credentials)
              masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)

              result.masked_credentials = masked_credentials
              result.original_credentials = decrypted_credentials

      return result
  @staticmethod
  def api_provider_to_controller(
      db_provider: ApiToolProvider,
  ) -> ApiToolProviderController:
      """
      build an api tool provider controller from a stored api provider

      The auth type is taken from the ``auth_type`` entry of the provider credentials;
      unknown or missing values map to ``ApiProviderAuthType.NONE``.
      """
      declared_auth_type = db_provider.credentials.get("auth_type")

      if declared_auth_type == "api_key_query":
          auth_type = ApiProviderAuthType.API_KEY_QUERY
      elif declared_auth_type in ("api_key_header", "api_key"):
          # "api_key" is kept for backward compatibility
          auth_type = ApiProviderAuthType.API_KEY_HEADER
      else:
          auth_type = ApiProviderAuthType.NONE

      # package tool provider controller
      return ApiToolProviderController.from_db(db_provider=db_provider, auth_type=auth_type)
  @staticmethod
  def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
      """
      build a workflow tool provider controller from a stored workflow provider

      :param db_provider: the stored workflow tool provider
      :return: the controller created by ``WorkflowToolProviderController.from_db``
      """
      controller = WorkflowToolProviderController.from_db(db_provider)
      return controller
  @staticmethod
  def workflow_provider_to_user_provider(
      provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
  ):
      """
      convert a workflow provider controller to a user provider

      :param provider_controller: the workflow tool provider controller
      :param labels: labels of the provider, an empty list is used when omitted
      :return: a ``ToolProviderApiEntity`` of type WORKFLOW, team-authorized, with no
          credentials, no plugin information and an empty tool list
      """
      identity = provider_controller.entity.identity
      provider_labels = labels or []

      user_provider = ToolProviderApiEntity(
          id=provider_controller.provider_id,
          author=identity.author,
          name=identity.name,
          description=identity.description,
          icon=identity.icon,
          icon_dark=identity.icon_dark,
          label=identity.label,
          type=ToolProviderType.WORKFLOW,
          masked_credentials={},
          is_team_authorization=True,
          plugin_id=None,
          plugin_unique_identifier=None,
          tools=[],
          labels=provider_labels,
      )
      return user_provider
  @staticmethod
  def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
      """
      convert a stored mcp provider to a user provider

      :param db_provider: the stored mcp tool provider
      :param for_list: when true the entity id is the record id, otherwise the server identifier
      :return: a ``ToolProviderApiEntity`` of type MCP including the provider's tools
      """
      owner = db_provider.load_user()
      entity_id = db_provider.id if for_list else db_provider.server_identifier
      author_name = owner.name if owner else "Anonymous"

      # the tool list is stored as a JSON string on the provider record
      stored_tools = json.loads(db_provider.tools)
      mcp_tools = [MCPTool(**stored_tool) for stored_tool in stored_tools]
      user_tools = ToolTransformService.mcp_tool_to_user_tool(db_provider, mcp_tools)

      provider_name = db_provider.name
      return ToolProviderApiEntity(
          id=entity_id,
          author=author_name,
          name=provider_name,
          icon=db_provider.provider_icon,
          type=ToolProviderType.MCP,
          is_team_authorization=db_provider.authed,
          server_url=db_provider.masked_server_url,
          tools=user_tools,
          updated_at=int(db_provider.updated_at.timestamp()),
          label=I18nObject(en_US=provider_name, zh_Hans=provider_name),
          description=I18nObject(en_US="", zh_Hans=""),
          server_identifier=db_provider.server_identifier,
      )
  @staticmethod
  def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
      """
      convert mcp tools to user tools

      :param mcp_provider: the stored mcp provider the tools belong to; its user is the author
      :param tools: the mcp tools to convert
      :return: one ``ToolApiEntity`` per tool, in the same order, with empty labels
      """
      user = mcp_provider.load_user()
      # the author is the same for every tool, resolve it once
      author = user.name if user else "Anonymous"

      user_tools = []
      for tool in tools:
          # the MCP specification makes a tool's description optional; use an empty
          # string so a missing description does not reach I18nObject as None
          description = tool.description or ""
          user_tools.append(
              ToolApiEntity(
                  author=author,
                  name=tool.name,
                  label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
                  description=I18nObject(en_US=description, zh_Hans=description),
                  parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
                  labels=[],
              )
          )
      return user_tools
  @classmethod
  def api_provider_to_user_provider(
      cls,
      provider_controller: ApiToolProviderController,
      db_provider: ApiToolProvider,
      decrypt_credentials: bool = True,
      labels: list[str] | None = None,
  ) -> ToolProviderApiEntity:
      """
      convert provider controller to user provider

      :param provider_controller: the api tool provider controller (credentials schema source)
      :param db_provider: the stored api tool provider
      :param decrypt_credentials: when true, decrypt the stored credentials and expose the
          masked version as ``masked_credentials``
      :param labels: labels of the provider, an empty list is used when omitted
      :return: a ``ToolProviderApiEntity`` of type API (``tools`` is left empty)
      :raises ValueError: when the provider has no user
      """
      # read the user once; NOTE(review): presumably a lazy lookup on the model - confirm
      user = db_provider.user
      if user is None:
          raise ValueError(f"user is None for api provider {db_provider.id}")

      # best effort: a failure to read the name is logged and the author stays "Anonymous"
      username = "Anonymous"
      try:
          username = user.name
      except Exception:
          logger.exception("failed to get user name for api provider %s", db_provider.id)

      # add provider into providers
      credentials = db_provider.credentials
      result = ToolProviderApiEntity(
          id=db_provider.id,
          author=username,
          name=db_provider.name,
          description=I18nObject(
              en_US=db_provider.description,
              zh_Hans=db_provider.description,
          ),
          icon=db_provider.icon,
          label=I18nObject(
              en_US=db_provider.name,
              zh_Hans=db_provider.name,
          ),
          type=ToolProviderType.API,
          plugin_id=None,
          plugin_unique_identifier=None,
          masked_credentials={},
          is_team_authorization=True,
          tools=[],
          labels=labels or [],
      )

      if decrypt_credentials:
          # init tool configuration
          tool_configuration = ProviderConfigEncrypter(
              tenant_id=db_provider.tenant_id,
              config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
              provider_type=provider_controller.provider_type.value,
              provider_identity=provider_controller.entity.identity.name,
          )

          # decrypt the credentials and mask the credentials
          decrypted_credentials = tool_configuration.decrypt(data=credentials)
          masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)

          result.masked_credentials = masked_credentials

      return result
  @staticmethod
  def convert_tool_entity_to_api_entity(
      tool: Union[ApiToolBundle, WorkflowTool, Tool],
      tenant_id: str,
      credentials: dict | None = None,
      labels: list[str] | None = None,
  ) -> ToolApiEntity:
      """
      convert tool to user tool

      :param tool: a ``Tool`` instance or an ``ApiToolBundle``
      :param tenant_id: the tenant id used for the forked tool runtime
      :param credentials: credentials for the forked tool runtime, empty when omitted
      :param labels: labels of the tool, an empty list is used when omitted
      :return: the tool api entity
      :raises ValueError: when ``tool`` is neither a ``Tool`` nor an ``ApiToolBundle``
      """
      if isinstance(tool, Tool):
          # fork tool runtime
          tool = tool.fork_tool_runtime(
              runtime=ToolRuntime(
                  credentials=credentials or {},
                  tenant_id=tenant_id,
              )
          )

          # start from a copy of the declared parameters so the tool entity is not mutated
          current_parameters = list(tool.entity.parameters or [])

          # override parameters with their runtime version
          for runtime_parameter in tool.get_runtime_parameters():
              for index, parameter in enumerate(current_parameters):
                  if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                      current_parameters[index] = runtime_parameter
                      break
              else:
                  # unknown runtime parameters are only added when they are form parameters
                  if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                      current_parameters.append(runtime_parameter)

          return ToolApiEntity(
              author=tool.entity.identity.author,
              name=tool.entity.identity.name,
              label=tool.entity.identity.label,
              description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
              output_schema=tool.entity.output_schema,
              parameters=current_parameters,
              labels=labels or [],
          )

      if isinstance(tool, ApiToolBundle):
          # name and label share the same fallback, so a missing operation id
          # yields empty strings in both instead of None in the label
          operation_id = tool.operation_id or ""
          summary = tool.summary or ""
          return ToolApiEntity(
              author=tool.author,
              name=operation_id,
              label=I18nObject(en_US=operation_id, zh_Hans=operation_id),
              description=I18nObject(en_US=summary, zh_Hans=summary),
              parameters=tool.parameters,
              labels=labels or [],
          )

      # previously the function fell through and returned None despite its annotation
      raise ValueError(f"unsupported tool type: {type(tool).__name__}")
  @staticmethod
  def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
      """
      Convert MCP JSON schema to tool parameters

      Only the top-level ``properties`` of an ``object`` schema are converted; nested
      schemas of ``array`` / ``object`` properties are passed along as ``input_schema``.

      :param schema: JSON schema dictionary
      :return: list of ToolParameter instances, empty when the schema is not an object
          schema with ``properties``
      """
      # JSON schema type names that are folded into the "number" parameter type
      type_mapping = {"integer": "number", "float": "number"}
      # property types whose full schema is forwarded as input_schema
      complex_types = ("array", "object")

      def create_parameter(
          name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
      ) -> ToolParameter:
          """Create a ToolParameter instance with given attributes"""
          input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
          return ToolParameter(
              name=name,
              llm_description=description,
              label=I18nObject(en_US=name),
              form=ToolParameter.ToolParameterForm.LLM,
              required=required,
              type=ToolParameter.ToolParameterType(param_type),
              human_description=I18nObject(en_US=description),
              **input_schema_dict,
          )

      if schema.get("type") != "object" or "properties" not in schema:
          return []

      # "required" may be absent or explicitly null
      required_names = schema.get("required") or []

      parameters = []
      for name, prop in schema["properties"].items():
          description = prop.get("description") or ""

          prop_type = prop.get("type", "string")
          if isinstance(prop_type, list):
              # JSON Schema allows a list of types, e.g. ["null", "string"]; use the first
              # non-null entry and fall back to "string" for an empty or null-only list.
              # NOTE(review): assumes "null" is not a ToolParameterType member - confirm
              prop_type = next((candidate for candidate in prop_type if candidate != "null"), "string")
          prop_type = type_mapping.get(prop_type, prop_type)

          input_schema = prop if prop_type in complex_types else None
          parameters.append(create_parameter(name, description, prop_type, name in required_names, input_schema))
      return parameters