You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

builtin_tools_manage_service.py 30KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. import json
  2. import logging
  3. import re
  4. from collections.abc import Mapping
  5. from pathlib import Path
  6. from typing import Any, Optional
  7. from sqlalchemy.orm import Session
  8. from configs import dify_config
  9. from constants import HIDDEN_VALUE, UNKNOWN_VALUE
  10. from core.helper.position_helper import is_filtered
  11. from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
  12. from core.plugin.entities.plugin import ToolProviderID
  13. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  14. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  15. from core.tools.entities.api_entities import (
  16. ToolApiEntity,
  17. ToolProviderApiEntity,
  18. ToolProviderCredentialApiEntity,
  19. ToolProviderCredentialInfoApiEntity,
  20. )
  21. from core.tools.entities.tool_entities import CredentialType
  22. from core.tools.errors import ToolProviderNotFoundError
  23. from core.tools.plugin_tool.provider import PluginToolProviderController
  24. from core.tools.tool_label_manager import ToolLabelManager
  25. from core.tools.tool_manager import ToolManager
  26. from core.tools.utils.encryption import create_provider_encrypter
  27. from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
  28. from extensions.ext_database import db
  29. from extensions.ext_redis import redis_client
  30. from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
  31. from services.plugin.plugin_service import PluginService
  32. from services.tools.tools_transform_service import ToolTransformService
# Module-level logger; generate_builtin_tool_provider_name uses it to report
# best-effort naming failures.
logger = logging.getLogger(__name__)
  34. class BuiltinToolManageService:
  35. __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
  36. @staticmethod
  37. def delete_custom_oauth_client_params(tenant_id: str, provider: str):
  38. """
  39. delete custom oauth client params
  40. """
  41. tool_provider = ToolProviderID(provider)
  42. with Session(db.engine) as session:
  43. session.query(ToolOAuthTenantClient).filter_by(
  44. tenant_id=tenant_id,
  45. provider=tool_provider.provider_name,
  46. plugin_id=tool_provider.plugin_id,
  47. ).delete()
  48. session.commit()
  49. return {"result": "success"}
  50. @staticmethod
  51. def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
  52. """
  53. get builtin tool provider oauth client schema
  54. """
  55. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  56. verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
  57. tenant_id, provider.plugin_unique_identifier
  58. )
  59. is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
  60. tenant_id, provider_name
  61. )
  62. is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
  63. provider_name
  64. )
  65. result = {
  66. "schema": provider.get_oauth_client_schema(),
  67. "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
  68. "is_system_oauth_params_exists": is_system_oauth_params_exists,
  69. "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
  70. "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
  71. }
  72. return result
  73. @staticmethod
  74. def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
  75. """
  76. list builtin tool provider tools
  77. :param tenant_id: the id of the tenant
  78. :param provider: the name of the provider
  79. :return: the list of tools
  80. """
  81. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  82. tools = provider_controller.get_tools()
  83. result: list[ToolApiEntity] = []
  84. for tool in tools or []:
  85. result.append(
  86. ToolTransformService.convert_tool_entity_to_api_entity(
  87. tool=tool,
  88. tenant_id=tenant_id,
  89. labels=ToolLabelManager.get_tool_labels(provider_controller),
  90. )
  91. )
  92. return result
  93. @staticmethod
  94. def get_builtin_tool_provider_info(tenant_id: str, provider: str):
  95. """
  96. get builtin tool provider info
  97. """
  98. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  99. # check if user has added the provider
  100. builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
  101. if builtin_provider is None:
  102. raise ValueError(f"you have not added provider {provider}")
  103. entity = ToolTransformService.builtin_provider_to_user_provider(
  104. provider_controller=provider_controller,
  105. db_provider=builtin_provider,
  106. decrypt_credentials=True,
  107. )
  108. entity.original_credentials = {}
  109. return entity
  110. @staticmethod
  111. def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
  112. """
  113. list builtin provider credentials schema
  114. :param credential_type: credential type
  115. :param provider_name: the name of the provider
  116. :param tenant_id: the id of the tenant
  117. :return: the list of tool providers
  118. """
  119. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  120. return provider.get_credentials_schema_by_type(credential_type)
  121. @staticmethod
  122. def update_builtin_tool_provider(
  123. user_id: str,
  124. tenant_id: str,
  125. provider: str,
  126. credential_id: str,
  127. credentials: dict | None = None,
  128. name: str | None = None,
  129. ):
  130. """
  131. update builtin tool provider
  132. """
  133. with Session(db.engine) as session:
  134. # get if the provider exists
  135. db_provider = (
  136. session.query(BuiltinToolProvider)
  137. .filter(
  138. BuiltinToolProvider.tenant_id == tenant_id,
  139. BuiltinToolProvider.id == credential_id,
  140. )
  141. .first()
  142. )
  143. if db_provider is None:
  144. raise ValueError(f"you have not added provider {provider}")
  145. try:
  146. if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
  147. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  148. if not provider_controller.need_credentials:
  149. raise ValueError(f"provider {provider} does not need credentials")
  150. encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
  151. tenant_id, db_provider, provider, provider_controller
  152. )
  153. original_credentials = encrypter.decrypt(db_provider.credentials)
  154. new_credentials: dict = {
  155. key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
  156. for key, value in credentials.items()
  157. }
  158. if CredentialType.of(db_provider.credential_type).is_validate_allowed():
  159. provider_controller.validate_credentials(user_id, new_credentials)
  160. # encrypt credentials
  161. db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
  162. cache.delete()
  163. # update name if provided
  164. if name and name != db_provider.name:
  165. # check if the name is already used
  166. if (
  167. session.query(BuiltinToolProvider)
  168. .filter_by(tenant_id=tenant_id, provider=provider, name=name)
  169. .count()
  170. > 0
  171. ):
  172. raise ValueError(f"the credential name '{name}' is already used")
  173. db_provider.name = name
  174. session.commit()
  175. except Exception as e:
  176. session.rollback()
  177. raise ValueError(str(e))
  178. return {"result": "success"}
  179. @staticmethod
  180. def add_builtin_tool_provider(
  181. user_id: str,
  182. api_type: CredentialType,
  183. tenant_id: str,
  184. provider: str,
  185. credentials: dict,
  186. name: str | None = None,
  187. ):
  188. """
  189. add builtin tool provider
  190. """
  191. try:
  192. with Session(db.engine) as session:
  193. lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
  194. with redis_client.lock(lock, timeout=20):
  195. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  196. if not provider_controller.need_credentials:
  197. raise ValueError(f"provider {provider} does not need credentials")
  198. provider_count = (
  199. session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
  200. )
  201. # check if the provider count is reached the limit
  202. if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
  203. raise ValueError(f"you have reached the maximum number of providers for {provider}")
  204. # validate credentials if allowed
  205. if CredentialType.of(api_type).is_validate_allowed():
  206. provider_controller.validate_credentials(user_id, credentials)
  207. # generate name if not provided
  208. if name is None or name == "":
  209. name = BuiltinToolManageService.generate_builtin_tool_provider_name(
  210. session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
  211. )
  212. else:
  213. # check if the name is already used
  214. if (
  215. session.query(BuiltinToolProvider)
  216. .filter_by(tenant_id=tenant_id, provider=provider, name=name)
  217. .count()
  218. > 0
  219. ):
  220. raise ValueError(f"the credential name '{name}' is already used")
  221. # create encrypter
  222. encrypter, _ = create_provider_encrypter(
  223. tenant_id=tenant_id,
  224. config=[
  225. x.to_basic_provider_config()
  226. for x in provider_controller.get_credentials_schema_by_type(api_type)
  227. ],
  228. cache=NoOpProviderCredentialCache(),
  229. )
  230. db_provider = BuiltinToolProvider(
  231. tenant_id=tenant_id,
  232. user_id=user_id,
  233. provider=provider,
  234. encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
  235. credential_type=api_type.value,
  236. name=name,
  237. )
  238. session.add(db_provider)
  239. session.commit()
  240. except Exception as e:
  241. session.rollback()
  242. raise ValueError(str(e))
  243. return {"result": "success"}
  244. @staticmethod
  245. def create_tool_encrypter(
  246. tenant_id: str,
  247. db_provider: BuiltinToolProvider,
  248. provider: str,
  249. provider_controller: BuiltinToolProviderController,
  250. ):
  251. encrypter, cache = create_provider_encrypter(
  252. tenant_id=tenant_id,
  253. config=[
  254. x.to_basic_provider_config()
  255. for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
  256. ],
  257. cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
  258. )
  259. return encrypter, cache
  260. @staticmethod
  261. def generate_builtin_tool_provider_name(
  262. session: Session, tenant_id: str, provider: str, credential_type: CredentialType
  263. ) -> str:
  264. try:
  265. db_providers = (
  266. session.query(BuiltinToolProvider)
  267. .filter_by(
  268. tenant_id=tenant_id,
  269. provider=provider,
  270. credential_type=credential_type.value,
  271. )
  272. .order_by(BuiltinToolProvider.created_at.desc())
  273. .all()
  274. )
  275. # Get the default name pattern
  276. default_pattern = f"{credential_type.get_name()}"
  277. # Find all names that match the default pattern: "{default_pattern} {number}"
  278. pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
  279. numbers = []
  280. for db_provider in db_providers:
  281. if db_provider.name:
  282. match = re.match(pattern, db_provider.name.strip())
  283. if match:
  284. numbers.append(int(match.group(1)))
  285. # If no default pattern names found, start with 1
  286. if not numbers:
  287. return f"{default_pattern} 1"
  288. # Find the next number
  289. max_number = max(numbers)
  290. return f"{default_pattern} {max_number + 1}"
  291. except Exception as e:
  292. logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
  293. # fallback
  294. return f"{credential_type.get_name()} 1"
  295. @staticmethod
  296. def get_builtin_tool_provider_credentials(
  297. tenant_id: str, provider_name: str
  298. ) -> list[ToolProviderCredentialApiEntity]:
  299. """
  300. get builtin tool provider credentials
  301. """
  302. with db.session.no_autoflush:
  303. providers = (
  304. db.session.query(BuiltinToolProvider)
  305. .filter_by(tenant_id=tenant_id, provider=provider_name)
  306. .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
  307. .all()
  308. )
  309. if len(providers) == 0:
  310. return []
  311. default_provider = providers[0]
  312. default_provider.is_default = True
  313. provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
  314. credentials: list[ToolProviderCredentialApiEntity] = []
  315. encrypters = {}
  316. for provider in providers:
  317. credential_type = provider.credential_type
  318. if credential_type not in encrypters:
  319. encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
  320. tenant_id, provider, provider.provider, provider_controller
  321. )[0]
  322. encrypter = encrypters[credential_type]
  323. decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
  324. credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
  325. provider=provider,
  326. credentials=decrypt_credential,
  327. )
  328. credentials.append(credential_entity)
  329. return credentials
  330. @staticmethod
  331. def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
  332. """
  333. get builtin tool provider credential info
  334. """
  335. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  336. supported_credential_types = provider_controller.get_supported_credential_types()
  337. credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
  338. credential_info = ToolProviderCredentialInfoApiEntity(
  339. supported_credential_types=supported_credential_types,
  340. is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
  341. credentials=credentials,
  342. )
  343. return credential_info
  344. @staticmethod
  345. def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
  346. """
  347. delete tool provider
  348. """
  349. with Session(db.engine) as session:
  350. db_provider = (
  351. session.query(BuiltinToolProvider)
  352. .filter(
  353. BuiltinToolProvider.tenant_id == tenant_id,
  354. BuiltinToolProvider.id == credential_id,
  355. )
  356. .first()
  357. )
  358. if db_provider is None:
  359. raise ValueError(f"you have not added provider {provider}")
  360. session.delete(db_provider)
  361. session.commit()
  362. # delete cache
  363. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  364. _, cache = BuiltinToolManageService.create_tool_encrypter(
  365. tenant_id, db_provider, provider, provider_controller
  366. )
  367. cache.delete()
  368. return {"result": "success"}
  369. @staticmethod
  370. def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
  371. """
  372. set default provider
  373. """
  374. with Session(db.engine) as session:
  375. # get provider
  376. target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
  377. if target_provider is None:
  378. raise ValueError("provider not found")
  379. # clear default provider
  380. session.query(BuiltinToolProvider).filter_by(
  381. tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
  382. ).update({"is_default": False})
  383. # set new default provider
  384. target_provider.is_default = True
  385. session.commit()
  386. return {"result": "success"}
  387. @staticmethod
  388. def is_oauth_system_client_exists(provider_name: str) -> bool:
  389. """
  390. check if oauth system client exists
  391. """
  392. tool_provider = ToolProviderID(provider_name)
  393. with Session(db.engine).no_autoflush as session:
  394. system_client: ToolOAuthSystemClient | None = (
  395. session.query(ToolOAuthSystemClient)
  396. .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
  397. .first()
  398. )
  399. return system_client is not None
  400. @staticmethod
  401. def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
  402. """
  403. check if oauth custom client is enabled
  404. """
  405. tool_provider = ToolProviderID(provider)
  406. with Session(db.engine).no_autoflush as session:
  407. user_client: ToolOAuthTenantClient | None = (
  408. session.query(ToolOAuthTenantClient)
  409. .filter_by(
  410. tenant_id=tenant_id,
  411. provider=tool_provider.provider_name,
  412. plugin_id=tool_provider.plugin_id,
  413. enabled=True,
  414. )
  415. .first()
  416. )
  417. return user_client is not None and user_client.enabled
  418. @staticmethod
  419. def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
  420. """
  421. get builtin tool provider
  422. """
  423. tool_provider = ToolProviderID(provider)
  424. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  425. encrypter, _ = create_provider_encrypter(
  426. tenant_id=tenant_id,
  427. config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
  428. cache=NoOpProviderCredentialCache(),
  429. )
  430. with Session(db.engine).no_autoflush as session:
  431. user_client: ToolOAuthTenantClient | None = (
  432. session.query(ToolOAuthTenantClient)
  433. .filter_by(
  434. tenant_id=tenant_id,
  435. provider=tool_provider.provider_name,
  436. plugin_id=tool_provider.plugin_id,
  437. enabled=True,
  438. )
  439. .first()
  440. )
  441. oauth_params: Mapping[str, Any] | None = None
  442. if user_client:
  443. oauth_params = encrypter.decrypt(user_client.oauth_params)
  444. return oauth_params
  445. # only verified provider can use custom oauth client
  446. is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
  447. tenant_id, provider.plugin_unique_identifier
  448. )
  449. if not is_verified:
  450. return oauth_params
  451. system_client: ToolOAuthSystemClient | None = (
  452. session.query(ToolOAuthSystemClient)
  453. .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
  454. .first()
  455. )
  456. if system_client:
  457. try:
  458. oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
  459. except Exception as e:
  460. raise ValueError(f"Error decrypting system oauth params: {e}")
  461. return oauth_params
  462. @staticmethod
  463. def get_builtin_tool_provider_icon(provider: str):
  464. """
  465. get tool provider icon and it's mimetype
  466. """
  467. icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
  468. icon_bytes = Path(icon_path).read_bytes()
  469. return icon_bytes, mime_type
  470. @staticmethod
  471. def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
  472. """
  473. list builtin tools
  474. """
  475. # get all builtin providers
  476. provider_controllers = ToolManager.list_builtin_providers(tenant_id)
  477. with db.session.no_autoflush:
  478. # get all user added providers
  479. db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
  480. # rewrite db_providers
  481. for db_provider in db_providers:
  482. db_provider.provider = str(ToolProviderID(db_provider.provider))
  483. # find provider
  484. def find_provider(provider):
  485. return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
  486. result: list[ToolProviderApiEntity] = []
  487. for provider_controller in provider_controllers:
  488. try:
  489. # handle include, exclude
  490. if is_filtered(
  491. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  492. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  493. data=provider_controller,
  494. name_func=lambda x: x.identity.name,
  495. ):
  496. continue
  497. # convert provider controller to user provider
  498. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  499. provider_controller=provider_controller,
  500. db_provider=find_provider(provider_controller.entity.identity.name),
  501. decrypt_credentials=True,
  502. )
  503. # add icon
  504. ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
  505. tools = provider_controller.get_tools()
  506. for tool in tools or []:
  507. user_builtin_provider.tools.append(
  508. ToolTransformService.convert_tool_entity_to_api_entity(
  509. tenant_id=tenant_id,
  510. tool=tool,
  511. labels=ToolLabelManager.get_tool_labels(provider_controller),
  512. )
  513. )
  514. result.append(user_builtin_provider)
  515. except Exception as e:
  516. raise e
  517. return BuiltinToolProviderSort.sort(result)
    @staticmethod
    def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
        """
        Fetch the tenant's stored builtin provider row for a provider.

        1. if a default row exists, return the default row
        2. otherwise return the oldest row

        :param provider_name: provider identifier; either a full ``ToolProviderID`` string
            or a legacy name without organization
        :param tenant_id: the id of the tenant
        :return: the matching ``BuiltinToolProvider`` or ``None``. On the main path its
            ``provider`` attribute is rewritten in memory to
            ``ToolProviderID(...).to_string()`` (this method does not commit).
        """
        with Session(db.engine) as session:
            try:
                full_provider_name = provider_name
                provider_id_entity = ToolProviderID(provider_name)
                # NOTE(review): this rebinds the parameter to the short name. If anything
                # below raises, the except-branch fallback queries by the short name rather
                # than the caller's original value; confirm this is intended.
                provider_name = provider_id_entity.provider_name
                if provider_id_entity.organization != "langgenius":
                    # non-langgenius providers are matched only by their full id
                    provider = (
                        session.query(BuiltinToolProvider)
                        .filter(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            BuiltinToolProvider.provider == full_provider_name,
                        )
                        .order_by(
                            BuiltinToolProvider.is_default.desc(),  # default=True first
                            BuiltinToolProvider.created_at.asc(),  # oldest first
                        )
                        .first()
                    )
                else:
                    # langgenius providers may be stored under the short name or the full id
                    provider = (
                        session.query(BuiltinToolProvider)
                        .filter(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            (BuiltinToolProvider.provider == provider_name)
                            | (BuiltinToolProvider.provider == full_provider_name),
                        )
                        .order_by(
                            BuiltinToolProvider.is_default.desc(),  # default=True first
                            BuiltinToolProvider.created_at.asc(),  # oldest first
                        )
                        .first()
                    )

                if provider is None:
                    return None

                # normalize the stored name to its canonical ToolProviderID string form
                provider.provider = ToolProviderID(provider.provider).to_string()
                return provider
            except Exception:
                # it's an old provider without organization
                # NOTE(review): this also catches database errors from the queries above,
                # in which case this fallback query runs on the same session; confirm.
                return (
                    session.query(BuiltinToolProvider)
                    .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
                    .order_by(
                        BuiltinToolProvider.is_default.desc(),  # default=True first
                        BuiltinToolProvider.created_at.asc(),  # oldest first
                    )
                    .first()
                )
  572. @staticmethod
  573. def save_custom_oauth_client_params(
  574. tenant_id: str,
  575. provider: str,
  576. client_params: Optional[dict] = None,
  577. enable_oauth_custom_client: Optional[bool] = None,
  578. ):
  579. """
  580. setup oauth custom client
  581. """
  582. if client_params is None and enable_oauth_custom_client is None:
  583. return {"result": "success"}
  584. tool_provider = ToolProviderID(provider)
  585. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  586. if not provider_controller:
  587. raise ToolProviderNotFoundError(f"Provider {provider} not found")
  588. if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
  589. raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
  590. with Session(db.engine) as session:
  591. custom_client_params = (
  592. session.query(ToolOAuthTenantClient)
  593. .filter_by(
  594. tenant_id=tenant_id,
  595. plugin_id=tool_provider.plugin_id,
  596. provider=tool_provider.provider_name,
  597. )
  598. .first()
  599. )
  600. # if the record does not exist, create a basic record
  601. if custom_client_params is None:
  602. custom_client_params = ToolOAuthTenantClient(
  603. tenant_id=tenant_id,
  604. plugin_id=tool_provider.plugin_id,
  605. provider=tool_provider.provider_name,
  606. )
  607. session.add(custom_client_params)
  608. if client_params is not None:
  609. encrypter, _ = create_provider_encrypter(
  610. tenant_id=tenant_id,
  611. config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
  612. cache=NoOpProviderCredentialCache(),
  613. )
  614. original_params = encrypter.decrypt(custom_client_params.oauth_params)
  615. new_params: dict = {
  616. key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
  617. for key, value in client_params.items()
  618. }
  619. custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
  620. if enable_oauth_custom_client is not None:
  621. custom_client_params.enabled = enable_oauth_custom_client
  622. session.commit()
  623. return {"result": "success"}
  624. @staticmethod
  625. def get_custom_oauth_client_params(tenant_id: str, provider: str):
  626. """
  627. get custom oauth client params
  628. """
  629. with Session(db.engine) as session:
  630. tool_provider = ToolProviderID(provider)
  631. custom_oauth_client_params: ToolOAuthTenantClient | None = (
  632. session.query(ToolOAuthTenantClient)
  633. .filter_by(
  634. tenant_id=tenant_id,
  635. plugin_id=tool_provider.plugin_id,
  636. provider=tool_provider.provider_name,
  637. )
  638. .first()
  639. )
  640. if custom_oauth_client_params is None:
  641. return {}
  642. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  643. if not provider_controller:
  644. raise ToolProviderNotFoundError(f"Provider {provider} not found")
  645. if not isinstance(provider_controller, BuiltinToolProviderController):
  646. raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
  647. encrypter, _ = create_provider_encrypter(
  648. tenant_id=tenant_id,
  649. config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
  650. cache=NoOpProviderCredentialCache(),
  651. )
  652. return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))