Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

tools_transform_service.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. import json
  2. import logging
  3. from typing import Any, Optional, Union, cast
  4. from yarl import URL
  5. from configs import dify_config
  6. from core.helper.provider_cache import ToolProviderCredentialsCache
  7. from core.mcp.types import Tool as MCPTool
  8. from core.tools.__base.tool import Tool
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  11. from core.tools.custom_tool.provider import ApiToolProviderController
  12. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
  13. from core.tools.entities.common_entities import I18nObject
  14. from core.tools.entities.tool_bundle import ApiToolBundle
  15. from core.tools.entities.tool_entities import (
  16. ApiProviderAuthType,
  17. CredentialType,
  18. ToolParameter,
  19. ToolProviderType,
  20. )
  21. from core.tools.plugin_tool.provider import PluginToolProviderController
  22. from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
  23. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  24. from core.tools.workflow_as_tool.tool import WorkflowTool
  25. from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
  26. logger = logging.getLogger(__name__)
  27. class ToolTransformService:
  28. @classmethod
  29. def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
  30. url_prefix = (
  31. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
  32. )
  33. return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
  34. @classmethod
  35. def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
  36. """
  37. get tool provider icon url
  38. """
  39. url_prefix = (
  40. URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
  41. )
  42. if provider_type == ToolProviderType.BUILT_IN.value:
  43. return str(url_prefix / "builtin" / provider_name / "icon")
  44. elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
  45. try:
  46. if isinstance(icon, str):
  47. return cast(dict, json.loads(icon))
  48. return icon
  49. except Exception:
  50. return {"background": "#252525", "content": "\ud83d\ude01"}
  51. elif provider_type == ToolProviderType.MCP.value:
  52. return icon
  53. return ""
  54. @staticmethod
  55. def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
  56. """
  57. repack provider
  58. :param tenant_id: the tenant id
  59. :param provider: the provider dict
  60. """
  61. if isinstance(provider, dict) and "icon" in provider:
  62. provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
  63. provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
  64. )
  65. elif isinstance(provider, ToolProviderApiEntity):
  66. if provider.plugin_id:
  67. if isinstance(provider.icon, str):
  68. provider.icon = ToolTransformService.get_plugin_icon_url(
  69. tenant_id=tenant_id, filename=provider.icon
  70. )
  71. if isinstance(provider.icon_dark, str) and provider.icon_dark:
  72. provider.icon_dark = ToolTransformService.get_plugin_icon_url(
  73. tenant_id=tenant_id, filename=provider.icon_dark
  74. )
  75. else:
  76. provider.icon = ToolTransformService.get_tool_provider_icon_url(
  77. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
  78. )
  79. if provider.icon_dark:
  80. provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
  81. provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
  82. )
  83. @classmethod
  84. def builtin_provider_to_user_provider(
  85. cls,
  86. provider_controller: BuiltinToolProviderController | PluginToolProviderController,
  87. db_provider: Optional[BuiltinToolProvider],
  88. decrypt_credentials: bool = True,
  89. ) -> ToolProviderApiEntity:
  90. """
  91. convert provider controller to user provider
  92. """
  93. result = ToolProviderApiEntity(
  94. id=provider_controller.entity.identity.name,
  95. author=provider_controller.entity.identity.author,
  96. name=provider_controller.entity.identity.name,
  97. description=provider_controller.entity.identity.description,
  98. icon=provider_controller.entity.identity.icon,
  99. icon_dark=provider_controller.entity.identity.icon_dark,
  100. label=provider_controller.entity.identity.label,
  101. type=ToolProviderType.BUILT_IN,
  102. masked_credentials={},
  103. is_team_authorization=False,
  104. plugin_id=None,
  105. tools=[],
  106. labels=provider_controller.tool_labels,
  107. )
  108. if isinstance(provider_controller, PluginToolProviderController):
  109. result.plugin_id = provider_controller.plugin_id
  110. result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
  111. # get credentials schema
  112. schema = {
  113. x.to_basic_provider_config().name: x
  114. for x in provider_controller.get_credentials_schema_by_type(
  115. CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
  116. )
  117. }
  118. for name, value in schema.items():
  119. if result.masked_credentials:
  120. result.masked_credentials[name] = ""
  121. # check if the provider need credentials
  122. if not provider_controller.need_credentials:
  123. result.is_team_authorization = True
  124. result.allow_delete = False
  125. elif db_provider:
  126. result.is_team_authorization = True
  127. if decrypt_credentials:
  128. credentials = db_provider.credentials
  129. # init tool configuration
  130. encrypter, _ = create_provider_encrypter(
  131. tenant_id=db_provider.tenant_id,
  132. config=[
  133. x.to_basic_provider_config()
  134. for x in provider_controller.get_credentials_schema_by_type(
  135. CredentialType.of(db_provider.credential_type)
  136. )
  137. ],
  138. cache=ToolProviderCredentialsCache(
  139. tenant_id=db_provider.tenant_id,
  140. provider=db_provider.provider,
  141. credential_id=db_provider.id,
  142. ),
  143. )
  144. # decrypt the credentials and mask the credentials
  145. decrypted_credentials = encrypter.decrypt(data=credentials)
  146. masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
  147. result.masked_credentials = masked_credentials
  148. result.original_credentials = decrypted_credentials
  149. return result
  150. @staticmethod
  151. def api_provider_to_controller(
  152. db_provider: ApiToolProvider,
  153. ) -> ApiToolProviderController:
  154. """
  155. convert provider controller to user provider
  156. """
  157. # package tool provider controller
  158. auth_type = ApiProviderAuthType.NONE
  159. credentials_auth_type = db_provider.credentials.get("auth_type")
  160. if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
  161. auth_type = ApiProviderAuthType.API_KEY_HEADER
  162. elif credentials_auth_type == "api_key_query":
  163. auth_type = ApiProviderAuthType.API_KEY_QUERY
  164. controller = ApiToolProviderController.from_db(
  165. db_provider=db_provider,
  166. auth_type=auth_type,
  167. )
  168. return controller
  169. @staticmethod
  170. def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
  171. """
  172. convert provider controller to provider
  173. """
  174. return WorkflowToolProviderController.from_db(db_provider)
  175. @staticmethod
  176. def workflow_provider_to_user_provider(
  177. provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
  178. ):
  179. """
  180. convert provider controller to user provider
  181. """
  182. return ToolProviderApiEntity(
  183. id=provider_controller.provider_id,
  184. author=provider_controller.entity.identity.author,
  185. name=provider_controller.entity.identity.name,
  186. description=provider_controller.entity.identity.description,
  187. icon=provider_controller.entity.identity.icon,
  188. icon_dark=provider_controller.entity.identity.icon_dark,
  189. label=provider_controller.entity.identity.label,
  190. type=ToolProviderType.WORKFLOW,
  191. masked_credentials={},
  192. is_team_authorization=True,
  193. plugin_id=None,
  194. plugin_unique_identifier=None,
  195. tools=[],
  196. labels=labels or [],
  197. )
  198. @staticmethod
  199. def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
  200. user = db_provider.load_user()
  201. return ToolProviderApiEntity(
  202. id=db_provider.server_identifier if not for_list else db_provider.id,
  203. author=user.name if user else "Anonymous",
  204. name=db_provider.name,
  205. icon=db_provider.provider_icon,
  206. type=ToolProviderType.MCP,
  207. is_team_authorization=db_provider.authed,
  208. server_url=db_provider.masked_server_url,
  209. tools=ToolTransformService.mcp_tool_to_user_tool(
  210. db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
  211. ),
  212. updated_at=int(db_provider.updated_at.timestamp()),
  213. label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
  214. description=I18nObject(en_US="", zh_Hans=""),
  215. server_identifier=db_provider.server_identifier,
  216. )
  217. @staticmethod
  218. def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
  219. user = mcp_provider.load_user()
  220. return [
  221. ToolApiEntity(
  222. author=user.name if user else "Anonymous",
  223. name=tool.name,
  224. label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
  225. description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
  226. parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
  227. labels=[],
  228. )
  229. for tool in tools
  230. ]
  231. @classmethod
  232. def api_provider_to_user_provider(
  233. cls,
  234. provider_controller: ApiToolProviderController,
  235. db_provider: ApiToolProvider,
  236. decrypt_credentials: bool = True,
  237. labels: list[str] | None = None,
  238. ) -> ToolProviderApiEntity:
  239. """
  240. convert provider controller to user provider
  241. """
  242. username = "Anonymous"
  243. if db_provider.user is None:
  244. raise ValueError(f"user is None for api provider {db_provider.id}")
  245. try:
  246. user = db_provider.user
  247. if not user:
  248. raise ValueError("user not found")
  249. username = user.name
  250. except Exception:
  251. logger.exception(f"failed to get user name for api provider {db_provider.id}")
  252. # add provider into providers
  253. credentials = db_provider.credentials
  254. result = ToolProviderApiEntity(
  255. id=db_provider.id,
  256. author=username,
  257. name=db_provider.name,
  258. description=I18nObject(
  259. en_US=db_provider.description,
  260. zh_Hans=db_provider.description,
  261. ),
  262. icon=db_provider.icon,
  263. label=I18nObject(
  264. en_US=db_provider.name,
  265. zh_Hans=db_provider.name,
  266. ),
  267. type=ToolProviderType.API,
  268. plugin_id=None,
  269. plugin_unique_identifier=None,
  270. masked_credentials={},
  271. is_team_authorization=True,
  272. tools=[],
  273. labels=labels or [],
  274. )
  275. if decrypt_credentials:
  276. # init tool configuration
  277. encrypter, _ = create_tool_provider_encrypter(
  278. tenant_id=db_provider.tenant_id,
  279. controller=provider_controller,
  280. )
  281. # decrypt the credentials and mask the credentials
  282. decrypted_credentials = encrypter.decrypt(data=credentials)
  283. masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
  284. result.masked_credentials = masked_credentials
  285. return result
  286. @staticmethod
  287. def convert_tool_entity_to_api_entity(
  288. tool: Union[ApiToolBundle, WorkflowTool, Tool],
  289. tenant_id: str,
  290. labels: list[str] | None = None,
  291. ) -> ToolApiEntity:
  292. """
  293. convert tool to user tool
  294. """
  295. if isinstance(tool, Tool):
  296. # fork tool runtime
  297. tool = tool.fork_tool_runtime(
  298. runtime=ToolRuntime(
  299. credentials={},
  300. tenant_id=tenant_id,
  301. )
  302. )
  303. # get tool parameters
  304. base_parameters = tool.entity.parameters or []
  305. # get tool runtime parameters
  306. runtime_parameters = tool.get_runtime_parameters()
  307. # merge parameters using a functional approach to avoid type issues
  308. merged_parameters: list[ToolParameter] = []
  309. # create a mapping of runtime parameters for quick lookup
  310. runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
  311. # process base parameters, replacing with runtime versions if they exist
  312. for base_param in base_parameters:
  313. key = (base_param.name, base_param.form)
  314. if key in runtime_param_map:
  315. merged_parameters.append(runtime_param_map[key])
  316. else:
  317. merged_parameters.append(base_param)
  318. # add any runtime parameters that weren't in base parameters
  319. for runtime_parameter in runtime_parameters:
  320. if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  321. # check if this parameter is already in merged_parameters
  322. already_exists = any(
  323. p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
  324. )
  325. if not already_exists:
  326. merged_parameters.append(runtime_parameter)
  327. return ToolApiEntity(
  328. author=tool.entity.identity.author,
  329. name=tool.entity.identity.name,
  330. label=tool.entity.identity.label,
  331. description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
  332. output_schema=tool.entity.output_schema,
  333. parameters=merged_parameters,
  334. labels=labels or [],
  335. )
  336. elif isinstance(tool, ApiToolBundle):
  337. return ToolApiEntity(
  338. author=tool.author,
  339. name=tool.operation_id or "",
  340. label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
  341. description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
  342. parameters=tool.parameters,
  343. labels=labels or [],
  344. )
  345. else:
  346. # Handle WorkflowTool case
  347. raise ValueError(f"Unsupported tool type: {type(tool)}")
  348. @staticmethod
  349. def convert_builtin_provider_to_credential_entity(
  350. provider: BuiltinToolProvider, credentials: dict
  351. ) -> ToolProviderCredentialApiEntity:
  352. return ToolProviderCredentialApiEntity(
  353. id=provider.id,
  354. name=provider.name,
  355. provider=provider.provider,
  356. credential_type=CredentialType.of(provider.credential_type),
  357. is_default=provider.is_default,
  358. credentials=credentials,
  359. )
  360. @staticmethod
  361. def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
  362. """
  363. Convert MCP JSON schema to tool parameters
  364. :param schema: JSON schema dictionary
  365. :return: list of ToolParameter instances
  366. """
  367. def create_parameter(
  368. name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
  369. ) -> ToolParameter:
  370. """Create a ToolParameter instance with given attributes"""
  371. input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
  372. return ToolParameter(
  373. name=name,
  374. llm_description=description,
  375. label=I18nObject(en_US=name),
  376. form=ToolParameter.ToolParameterForm.LLM,
  377. required=required,
  378. type=ToolParameter.ToolParameterType(param_type),
  379. human_description=I18nObject(en_US=description),
  380. **input_schema_dict,
  381. )
  382. def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
  383. """Process properties recursively"""
  384. TYPE_MAPPING = {"integer": "number", "float": "number"}
  385. COMPLEX_TYPES = ["array", "object"]
  386. parameters = []
  387. for name, prop in props.items():
  388. current_description = prop.get("description", "")
  389. prop_type = prop.get("type", "string")
  390. if isinstance(prop_type, list):
  391. prop_type = prop_type[0]
  392. if prop_type in TYPE_MAPPING:
  393. prop_type = TYPE_MAPPING[prop_type]
  394. input_schema = prop if prop_type in COMPLEX_TYPES else None
  395. parameters.append(
  396. create_parameter(name, current_description, prop_type, name in required, input_schema)
  397. )
  398. return parameters
  399. if schema.get("type") == "object" and "properties" in schema:
  400. return process_properties(schema["properties"], schema.get("required", []))
  401. return []