您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

builtin_tools_manage_service.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from sqlalchemy.orm import Session
  5. from configs import dify_config
  6. from core.helper.position_helper import is_filtered
  7. from core.model_runtime.utils.encoders import jsonable_encoder
  8. from core.plugin.entities.plugin import ToolProviderID
  9. from core.plugin.impl.exc import PluginDaemonClientSideError
  10. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  11. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  12. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  13. from core.tools.tool_label_manager import ToolLabelManager
  14. from core.tools.tool_manager import ToolManager
  15. from core.tools.utils.configuration import ProviderConfigEncrypter
  16. from extensions.ext_database import db
  17. from models.tools import BuiltinToolProvider
  18. from services.tools.tools_transform_service import ToolTransformService
# Module-level logger for this service; no code in the class below currently references it.
logger = logging.getLogger(__name__)
  20. class BuiltinToolManageService:
  21. @staticmethod
  22. def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
  23. """
  24. list builtin tool provider tools
  25. :param tenant_id: the id of the tenant
  26. :param provider: the name of the provider
  27. :return: the list of tools
  28. """
  29. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  30. tools = provider_controller.get_tools()
  31. tool_provider_configurations = ProviderConfigEncrypter(
  32. tenant_id=tenant_id,
  33. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  34. provider_type=provider_controller.provider_type.value,
  35. provider_identity=provider_controller.entity.identity.name,
  36. )
  37. # check if user has added the provider
  38. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  39. credentials = {}
  40. if builtin_provider is not None:
  41. # get credentials
  42. credentials = builtin_provider.credentials
  43. credentials = tool_provider_configurations.decrypt(credentials)
  44. result: list[ToolApiEntity] = []
  45. for tool in tools or []:
  46. result.append(
  47. ToolTransformService.convert_tool_entity_to_api_entity(
  48. tool=tool,
  49. credentials=credentials,
  50. tenant_id=tenant_id,
  51. labels=ToolLabelManager.get_tool_labels(provider_controller),
  52. )
  53. )
  54. return result
  55. @staticmethod
  56. def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
  57. """
  58. get builtin tool provider info
  59. """
  60. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  61. tool_provider_configurations = ProviderConfigEncrypter(
  62. tenant_id=tenant_id,
  63. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  64. provider_type=provider_controller.provider_type.value,
  65. provider_identity=provider_controller.entity.identity.name,
  66. )
  67. # check if user has added the provider
  68. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  69. credentials = {}
  70. if builtin_provider is not None:
  71. # get credentials
  72. credentials = builtin_provider.credentials
  73. credentials = tool_provider_configurations.decrypt(credentials)
  74. entity = ToolTransformService.builtin_provider_to_user_provider(
  75. provider_controller=provider_controller,
  76. db_provider=builtin_provider,
  77. decrypt_credentials=True,
  78. )
  79. entity.original_credentials = {}
  80. return entity
  81. @staticmethod
  82. def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
  83. """
  84. list builtin provider credentials schema
  85. :param provider_name: the name of the provider
  86. :param tenant_id: the id of the tenant
  87. :return: the list of tool providers
  88. """
  89. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  90. return jsonable_encoder(provider.get_credentials_schema())
  91. @staticmethod
  92. def update_builtin_tool_provider(
  93. session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
  94. ):
  95. """
  96. update builtin tool provider
  97. """
  98. # get if the provider exists
  99. provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  100. try:
  101. # get provider
  102. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  103. if not provider_controller.need_credentials:
  104. raise ValueError(f"provider {provider_name} does not need credentials")
  105. tool_configuration = ProviderConfigEncrypter(
  106. tenant_id=tenant_id,
  107. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  108. provider_type=provider_controller.provider_type.value,
  109. provider_identity=provider_controller.entity.identity.name,
  110. )
  111. # get original credentials if exists
  112. if provider is not None:
  113. original_credentials = tool_configuration.decrypt(provider.credentials)
  114. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  115. # check if the credential has changed, save the original credential
  116. for name, value in credentials.items():
  117. if name in masked_credentials and value == masked_credentials[name]:
  118. credentials[name] = original_credentials[name]
  119. # validate credentials
  120. provider_controller.validate_credentials(user_id, credentials)
  121. # encrypt credentials
  122. credentials = tool_configuration.encrypt(credentials)
  123. except (
  124. PluginDaemonClientSideError,
  125. ToolProviderNotFoundError,
  126. ToolNotFoundError,
  127. ToolProviderCredentialValidationError,
  128. ) as e:
  129. raise ValueError(str(e))
  130. if provider is None:
  131. # create provider
  132. provider = BuiltinToolProvider(
  133. tenant_id=tenant_id,
  134. user_id=user_id,
  135. provider=provider_name,
  136. encrypted_credentials=json.dumps(credentials),
  137. )
  138. db.session.add(provider)
  139. else:
  140. provider.encrypted_credentials = json.dumps(credentials)
  141. # delete cache
  142. tool_configuration.delete_tool_credentials_cache()
  143. db.session.commit()
  144. return {"result": "success"}
  145. @staticmethod
  146. def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
  147. """
  148. get builtin tool provider credentials
  149. """
  150. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  151. if provider_obj is None:
  152. return {}
  153. provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
  154. tool_configuration = ProviderConfigEncrypter(
  155. tenant_id=tenant_id,
  156. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  157. provider_type=provider_controller.provider_type.value,
  158. provider_identity=provider_controller.entity.identity.name,
  159. )
  160. credentials = tool_configuration.decrypt(provider_obj.credentials)
  161. credentials = tool_configuration.mask_tool_credentials(credentials)
  162. return credentials
  163. @staticmethod
  164. def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
  165. """
  166. delete tool provider
  167. """
  168. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  169. if provider_obj is None:
  170. raise ValueError(f"you have not added provider {provider_name}")
  171. db.session.delete(provider_obj)
  172. db.session.commit()
  173. # delete cache
  174. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  175. tool_configuration = ProviderConfigEncrypter(
  176. tenant_id=tenant_id,
  177. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  178. provider_type=provider_controller.provider_type.value,
  179. provider_identity=provider_controller.entity.identity.name,
  180. )
  181. tool_configuration.delete_tool_credentials_cache()
  182. return {"result": "success"}
  183. @staticmethod
  184. def get_builtin_tool_provider_icon(provider: str):
  185. """
  186. get tool provider icon and it's mimetype
  187. """
  188. icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
  189. icon_bytes = Path(icon_path).read_bytes()
  190. return icon_bytes, mime_type
  191. @staticmethod
  192. def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
  193. """
  194. list builtin tools
  195. """
  196. # get all builtin providers
  197. provider_controllers = ToolManager.list_builtin_providers(tenant_id)
  198. with db.session.no_autoflush:
  199. # get all user added providers
  200. db_providers: list[BuiltinToolProvider] = (
  201. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
  202. )
  203. # rewrite db_providers
  204. for db_provider in db_providers:
  205. db_provider.provider = str(ToolProviderID(db_provider.provider))
  206. # find provider
  207. def find_provider(provider):
  208. return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
  209. result: list[ToolProviderApiEntity] = []
  210. for provider_controller in provider_controllers:
  211. try:
  212. # handle include, exclude
  213. if is_filtered(
  214. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  215. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  216. data=provider_controller,
  217. name_func=lambda x: x.identity.name,
  218. ):
  219. continue
  220. # convert provider controller to user provider
  221. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  222. provider_controller=provider_controller,
  223. db_provider=find_provider(provider_controller.entity.identity.name),
  224. decrypt_credentials=True,
  225. )
  226. # add icon
  227. ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
  228. tools = provider_controller.get_tools()
  229. for tool in tools or []:
  230. user_builtin_provider.tools.append(
  231. ToolTransformService.convert_tool_entity_to_api_entity(
  232. tenant_id=tenant_id,
  233. tool=tool,
  234. credentials=user_builtin_provider.original_credentials,
  235. labels=ToolLabelManager.get_tool_labels(provider_controller),
  236. )
  237. )
  238. result.append(user_builtin_provider)
  239. except Exception as e:
  240. raise e
  241. return BuiltinToolProviderSort.sort(result)
  242. @staticmethod
  243. def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
  244. try:
  245. full_provider_name = provider_name
  246. provider_id_entity = ToolProviderID(provider_name)
  247. provider_name = provider_id_entity.provider_name
  248. if provider_id_entity.organization != "langgenius":
  249. provider_obj = (
  250. db.session.query(BuiltinToolProvider)
  251. .filter(
  252. BuiltinToolProvider.tenant_id == tenant_id,
  253. BuiltinToolProvider.provider == full_provider_name,
  254. )
  255. .first()
  256. )
  257. else:
  258. provider_obj = (
  259. db.session.query(BuiltinToolProvider)
  260. .filter(
  261. BuiltinToolProvider.tenant_id == tenant_id,
  262. (BuiltinToolProvider.provider == provider_name)
  263. | (BuiltinToolProvider.provider == full_provider_name),
  264. )
  265. .first()
  266. )
  267. if provider_obj is None:
  268. return None
  269. provider_obj.provider = ToolProviderID(provider_obj.provider).to_string()
  270. return provider_obj
  271. except Exception:
  272. # it's an old provider without organization
  273. return (
  274. db.session.query(BuiltinToolProvider)
  275. .filter(
  276. BuiltinToolProvider.tenant_id == tenant_id,
  277. (BuiltinToolProvider.provider == provider_name),
  278. )
  279. .first()
  280. )