You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

builtin_tools_manage_service.py 31KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756
  1. import json
  2. import logging
  3. import re
  4. from collections.abc import Mapping
  5. from pathlib import Path
  6. from typing import Any, Optional
  7. from sqlalchemy import exists, select
  8. from sqlalchemy.orm import Session
  9. from configs import dify_config
  10. from constants import HIDDEN_VALUE, UNKNOWN_VALUE
  11. from core.helper.position_helper import is_filtered
  12. from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
  13. from core.plugin.entities.plugin import ToolProviderID
  14. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  15. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  16. from core.tools.entities.api_entities import (
  17. ToolApiEntity,
  18. ToolProviderApiEntity,
  19. ToolProviderCredentialApiEntity,
  20. ToolProviderCredentialInfoApiEntity,
  21. )
  22. from core.tools.entities.tool_entities import CredentialType
  23. from core.tools.errors import ToolProviderNotFoundError
  24. from core.tools.plugin_tool.provider import PluginToolProviderController
  25. from core.tools.tool_label_manager import ToolLabelManager
  26. from core.tools.tool_manager import ToolManager
  27. from core.tools.utils.encryption import create_provider_encrypter
  28. from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
  29. from extensions.ext_database import db
  30. from extensions.ext_redis import redis_client
  31. from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
  32. from services.plugin.plugin_service import PluginService
  33. from services.tools.tools_transform_service import ToolTransformService
  34. logger = logging.getLogger(__name__)
  35. class BuiltinToolManageService:
  36. __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
  37. __DEFAULT_EXPIRES_AT__ = 2147483647
  @staticmethod
  def delete_custom_oauth_client_params(tenant_id: str, provider: str):
      """
      Remove the tenant's custom OAuth client record(s) for a provider.

      :param tenant_id: the id of the tenant
      :param provider: provider identifier, parsed with ``ToolProviderID``
      :return: ``{"result": "success"}``
      """
      provider_id = ToolProviderID(provider)
      with Session(db.engine) as session:
          # Rows are matched on tenant, provider name and plugin id together.
          matching_clients = session.query(ToolOAuthTenantClient).filter_by(
              tenant_id=tenant_id,
              provider=provider_id.provider_name,
              plugin_id=provider_id.plugin_id,
          )
          matching_clients.delete()
          session.commit()
      return {"result": "success"}
  @staticmethod
  def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
      """
      Collect the OAuth client configuration exposed for a builtin tool provider.

      :param tenant_id: the id of the tenant
      :param provider_name: the name of the provider
      :return: dict holding the provider's OAuth client schema, whether the tenant's
          custom client is enabled, whether system OAuth params exist (falsy for an
          unverified plugin provider), the masked custom client params and the
          OAuth redirect URI
      """
      controller = ToolManager.get_builtin_provider(provider_name, tenant_id)

      # Only plugin providers need a verification lookup; anything else counts as verified.
      if isinstance(controller, PluginToolProviderController):
          verified = PluginService.is_plugin_verified(tenant_id, controller.plugin_unique_identifier)
      else:
          verified = True

      custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider_name)
      # The system client is only looked up (and reported) for verified providers.
      system_params_exist = verified and BuiltinToolManageService.is_oauth_system_client_exists(provider_name)

      return {
          "schema": controller.get_oauth_client_schema(),
          "is_oauth_custom_client_enabled": custom_client_enabled,
          "is_system_oauth_params_exists": system_params_exist,
          "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
          "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
      }
  @staticmethod
  def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
      """
      List the tools of a builtin tool provider as API entities.

      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :return: one ``ToolApiEntity`` per tool; an empty list when the provider has no tools
      """
      provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
      tools = provider_controller.get_tools()
      if not tools:
          return []

      # The labels depend only on the provider controller, so resolve them once
      # instead of once per tool.
      labels = ToolLabelManager.get_tool_labels(provider_controller)
      return [
          ToolTransformService.convert_tool_entity_to_api_entity(
              tool=tool,
              tenant_id=tenant_id,
              labels=labels,
          )
          for tool in tools
      ]
  @staticmethod
  def get_builtin_tool_provider_info(tenant_id: str, provider: str):
      """
      Return the user-facing entity of a builtin provider the tenant has added.

      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :return: the provider entity with ``original_credentials`` reset to an empty dict
      :raises ValueError: if the tenant has no stored record for the provider
      """
      controller = ToolManager.get_builtin_provider(provider, tenant_id)

      # The tenant must have added the provider before its info can be shown.
      stored_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
      if stored_provider is None:
          raise ValueError(f"you have not added provider {provider}")

      provider_entity = ToolTransformService.builtin_provider_to_user_provider(
          provider_controller=controller,
          db_provider=stored_provider,
          decrypt_credentials=True,
      )
      # Blank out the original credentials on the entity before returning it.
      provider_entity.original_credentials = {}
      return provider_entity
  @staticmethod
  def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
      """
      Return a builtin provider's credentials schema for one credential type.

      :param provider_name: the name of the provider
      :param credential_type: the credential type whose schema is requested
      :param tenant_id: the id of the tenant
      :return: whatever the provider controller's ``get_credentials_schema_by_type`` returns
      """
      controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
      schema = controller.get_credentials_schema_by_type(credential_type)
      return schema
  @staticmethod
  def update_builtin_tool_provider(
      user_id: str,
      tenant_id: str,
      provider: str,
      credential_id: str,
      credentials: dict | None = None,
      name: str | None = None,
  ):
      """
      Update the credentials and/or the name of a stored builtin tool credential.

      :param user_id: the id of the user, forwarded to credential validation
      :param tenant_id: the id of the tenant owning the credential
      :param provider: the name of the provider
      :param credential_id: the id of the ``BuiltinToolProvider`` row to update
      :param credentials: new credential values; entries equal to ``HIDDEN_VALUE`` keep
          the stored value. Ignored when empty or when the credential type is not editable
      :param name: new credential name; ignored when empty or unchanged
      :return: ``{"result": "success"}``
      :raises ValueError: if the credential does not exist, or wrapping any error raised
          while updating (the original exception is attached as ``__cause__``)
      """
      with Session(db.engine) as session:
          # get if the provider exists
          db_provider = (
              session.query(BuiltinToolProvider)
              .where(
                  BuiltinToolProvider.tenant_id == tenant_id,
                  BuiltinToolProvider.id == credential_id,
              )
              .first()
          )
          if db_provider is None:
              raise ValueError(f"you have not added provider {provider}")

          try:
              credential_type = CredentialType.of(db_provider.credential_type)
              if credential_type.is_editable() and credentials:
                  provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
                  if not provider_controller.need_credentials:
                      raise ValueError(f"provider {provider} does not need credentials")

                  encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
                      tenant_id, db_provider, provider, provider_controller
                  )

                  # Fields submitted as HIDDEN_VALUE were not edited: restore the stored value.
                  original_credentials = encrypter.decrypt(db_provider.credentials)
                  new_credentials: dict = {
                      key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
                      for key, value in credentials.items()
                  }

                  if credential_type.is_validate_allowed():
                      provider_controller.validate_credentials(user_id, new_credentials)

                  # encrypt credentials
                  db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))

                  cache.delete()

              # update name if provided
              if name and name != db_provider.name:
                  # check if the name is already used
                  name_taken = session.scalar(
                      select(
                          exists().where(
                              BuiltinToolProvider.tenant_id == tenant_id,
                              BuiltinToolProvider.provider == provider,
                              BuiltinToolProvider.name == name,
                          )
                      )
                  )
                  if name_taken:
                      raise ValueError(f"the credential name '{name}' is already used")

                  db_provider.name = name

              session.commit()
          except Exception as e:
              session.rollback()
              # Callers expect ValueError; chain the cause so the real traceback is kept.
              raise ValueError(str(e)) from e
      return {"result": "success"}
  @staticmethod
  def add_builtin_tool_provider(
      user_id: str,
      api_type: CredentialType,
      tenant_id: str,
      provider: str,
      credentials: dict,
      expires_at: int = -1,
      name: str | None = None,
  ):
      """
      Store a new credential for a builtin tool provider.

      The creation runs under a Redis lock keyed by tenant and provider.

      :param user_id: the id of the user creating the credential
      :param api_type: the credential type of the new credential
      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :param credentials: plain credential values; they are encrypted before storage
      :param expires_at: expiry value stored on the record; ``None`` falls back to
          ``__DEFAULT_EXPIRES_AT__``
      :param name: credential name; generated when ``None`` or empty
      :return: ``{"result": "success"}``
      :raises ValueError: wrapping any error raised while creating the credential, e.g.
          the provider needs no credentials, the per-provider limit is reached or the
          name is already used (the original exception is attached as ``__cause__``)
      """
      with Session(db.engine) as session:
          try:
              lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
              with redis_client.lock(lock, timeout=20):
                  provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
                  if not provider_controller.need_credentials:
                      raise ValueError(f"provider {provider} does not need credentials")

                  provider_count = (
                      session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
                  )

                  # check if the provider count is reached the limit
                  if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
                      raise ValueError(f"you have reached the maximum number of providers for {provider}")

                  # validate credentials if allowed
                  if CredentialType.of(api_type).is_validate_allowed():
                      provider_controller.validate_credentials(user_id, credentials)

                  if not name:
                      # generate name if not provided
                      name = BuiltinToolManageService.generate_builtin_tool_provider_name(
                          session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
                      )
                  else:
                      # check if the name is already used
                      name_taken = session.scalar(
                          select(
                              exists().where(
                                  BuiltinToolProvider.tenant_id == tenant_id,
                                  BuiltinToolProvider.provider == provider,
                                  BuiltinToolProvider.name == name,
                              )
                          )
                      )
                      if name_taken:
                          raise ValueError(f"the credential name '{name}' is already used")

                  # create encrypter
                  encrypter, _ = create_provider_encrypter(
                      tenant_id=tenant_id,
                      config=[
                          x.to_basic_provider_config()
                          for x in provider_controller.get_credentials_schema_by_type(api_type)
                      ],
                      cache=NoOpProviderCredentialCache(),
                  )

                  db_provider = BuiltinToolProvider(
                      tenant_id=tenant_id,
                      user_id=user_id,
                      provider=provider,
                      encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
                      credential_type=api_type.value,
                      name=name,
                      expires_at=expires_at
                      if expires_at is not None
                      else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
                  )

                  session.add(db_provider)
                  session.commit()
          except Exception as e:
              session.rollback()
              # Callers expect ValueError; chain the cause so the real traceback is kept.
              raise ValueError(str(e)) from e
      return {"result": "success"}
  @staticmethod
  def create_tool_encrypter(
      tenant_id: str,
      db_provider: BuiltinToolProvider,
      provider: str,
      provider_controller: BuiltinToolProviderController,
  ):
      """
      Build the credential encrypter and credentials cache for a stored credential.

      :param tenant_id: the id of the tenant
      :param db_provider: the stored credential record; its credential type selects the
          schema and its id keys the cache
      :param provider: the name of the provider
      :param provider_controller: controller supplying the credentials schema
      :return: ``(encrypter, cache)`` as produced by ``create_provider_encrypter``
      """
      schema_items = provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
      basic_configs = [item.to_basic_provider_config() for item in schema_items]
      credentials_cache = ToolProviderCredentialsCache(
          tenant_id=tenant_id, provider=provider, credential_id=db_provider.id
      )
      encrypter, cache = create_provider_encrypter(
          tenant_id=tenant_id,
          config=basic_configs,
          cache=credentials_cache,
      )
      return encrypter, cache
  @staticmethod
  def generate_builtin_tool_provider_name(
      session: Session, tenant_id: str, provider: str, credential_type: CredentialType
  ) -> str:
      """
      Produce the next default credential name of the form "<type name> <number>".

      Existing credentials of the same tenant, provider and credential type are scanned
      for names matching that form; the result uses the highest number found plus one,
      or 1 when none match.

      :param session: database session used for the lookup
      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :param credential_type: credential type supplying the name prefix
      :return: the generated name; "<type name> 1" if the lookup fails for any reason
      """
      try:
          existing_rows = (
              session.query(BuiltinToolProvider)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=provider,
                  credential_type=credential_type.value,
              )
              .order_by(BuiltinToolProvider.created_at.desc())
              .all()
          )

          base_name = f"{credential_type.get_name()}"
          # Matches "<base name> <digits>" and captures the digits.
          numbered_name = re.compile(rf"^{re.escape(base_name)}\s+(\d+)$")

          used_numbers = [
              int(found.group(1))
              for row in existing_rows
              if row.name and (found := numbered_name.match(row.name.strip()))
          ]

          # With no numbered names yet, default=0 makes the first generated name end in 1.
          next_number = max(used_numbers, default=0) + 1
          return f"{base_name} {next_number}"
      except Exception as e:
          logger.warning("Error generating next provider name for %s: %s", provider, str(e))
          # fallback
          return f"{credential_type.get_name()} 1"
  @staticmethod
  def get_builtin_tool_provider_credentials(
      tenant_id: str, provider_name: str
  ) -> list[ToolProviderCredentialApiEntity]:
      """
      List the tenant's stored credentials for a provider, with values masked.

      Records are ordered default-first, then oldest-first; the first record is flagged
      as default on the loaded object before the entities are built.

      :param tenant_id: the id of the tenant
      :param provider_name: the name of the provider
      :return: one credential entity per stored record; empty list when there are none
      """
      with db.session.no_autoflush:
          stored = (
              db.session.query(BuiltinToolProvider)
              .filter_by(tenant_id=tenant_id, provider=provider_name)
              .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
              .all()
          )
          if not stored:
              return []

          # The head of the ordering is treated as the default credential.
          head = stored[0]
          head.is_default = True
          controller = ToolManager.get_builtin_provider(head.provider, tenant_id)

          entities: list[ToolProviderCredentialApiEntity] = []
          # One encrypter per credential type, created on first use.
          encrypter_by_type = {}
          for record in stored:
              type_key = record.credential_type
              encrypter = encrypter_by_type.get(type_key)
              if encrypter is None:
                  encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
                      tenant_id, record, record.provider, controller
                  )
                  encrypter_by_type[type_key] = encrypter

              masked = encrypter.mask_tool_credentials(encrypter.decrypt(record.credentials))
              entities.append(
                  ToolTransformService.convert_builtin_provider_to_credential_entity(
                      provider=record,
                      credentials=masked,
                  )
              )
          return entities
  @staticmethod
  def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
      """
      Bundle a provider's supported credential types, stored credentials and custom
      OAuth client flag into one entity.

      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :return: the assembled ``ToolProviderCredentialInfoApiEntity``
      """
      controller = ToolManager.get_builtin_provider(provider, tenant_id)
      credential_types = controller.get_supported_credential_types()
      stored_credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
      custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider)

      return ToolProviderCredentialInfoApiEntity(
          supported_credential_types=credential_types,
          is_oauth_custom_client_enabled=custom_client_enabled,
          credentials=stored_credentials,
      )
  @staticmethod
  def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
      """
      Delete a stored builtin tool credential and drop its cached credentials.

      :param tenant_id: the id of the tenant owning the credential
      :param provider: the name of the provider
      :param credential_id: the id of the ``BuiltinToolProvider`` row to delete
      :return: ``{"result": "success"}``
      :raises ValueError: if no matching credential exists for the tenant
      """
      with Session(db.engine) as session:
          lookup = session.query(BuiltinToolProvider).where(
              BuiltinToolProvider.tenant_id == tenant_id,
              BuiltinToolProvider.id == credential_id,
          )
          record = lookup.first()
          if record is None:
              raise ValueError(f"you have not added provider {provider}")

          session.delete(record)
          session.commit()

          # The row is gone; now invalidate the credentials cache keyed by its id.
          controller = ToolManager.get_builtin_provider(provider, tenant_id)
          _, credentials_cache = BuiltinToolManageService.create_tool_encrypter(
              tenant_id, record, provider, controller
          )
          credentials_cache.delete()
      return {"result": "success"}
  @staticmethod
  def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
      """
      Mark one stored credential as the default for a provider.

      :param tenant_id: the id of the tenant; the target credential must belong to it
      :param user_id: the id of the user; used when clearing the previous default
      :param provider: the name of the provider whose current default is cleared
      :param id: the id of the credential that becomes the default
      :return: ``{"result": "success"}``
      :raises ValueError: if the tenant has no credential with the given id
      """
      with Session(db.engine) as session:
          # Scope the lookup to the tenant: matching on id alone would let a caller
          # flip the default flag on another tenant's credential.
          target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
          if target_provider is None:
              raise ValueError("provider not found")

          # clear default provider
          session.query(BuiltinToolProvider).filter_by(
              tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
          ).update({"is_default": False})

          # set new default provider
          target_provider.is_default = True
          session.commit()
      return {"result": "success"}
  @staticmethod
  def is_oauth_system_client_exists(provider_name: str) -> bool:
      """
      Tell whether a system OAuth client record exists for the provider.

      :param provider_name: provider identifier, parsed with ``ToolProviderID``
      :return: ``True`` when a matching ``ToolOAuthSystemClient`` row is found
      """
      provider_id = ToolProviderID(provider_name)
      with Session(db.engine, autoflush=False) as session:
          lookup = session.query(ToolOAuthSystemClient).filter_by(
              plugin_id=provider_id.plugin_id, provider=provider_id.provider_name
          )
          found: ToolOAuthSystemClient | None = lookup.first()
          return found is not None
  @staticmethod
  def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
      """
      Tell whether the tenant has an enabled custom OAuth client for the provider.

      :param tenant_id: the id of the tenant
      :param provider: provider identifier, parsed with ``ToolProviderID``
      :return: ``False`` when no enabled record exists, otherwise the record's ``enabled`` value
      """
      provider_id = ToolProviderID(provider)
      with Session(db.engine, autoflush=False) as session:
          lookup = session.query(ToolOAuthTenantClient).filter_by(
              tenant_id=tenant_id,
              provider=provider_id.provider_name,
              plugin_id=provider_id.plugin_id,
              enabled=True,
          )
          tenant_client: ToolOAuthTenantClient | None = lookup.first()
          if tenant_client is None:
              return False
          return tenant_client.enabled
  @staticmethod
  def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
      """
      Resolve the OAuth client parameters to use for a provider.

      The tenant's enabled custom client takes precedence. Otherwise the system client
      is used, but only when the provider is not a plugin provider or the plugin is
      verified.

      :param tenant_id: the id of the tenant
      :param provider: provider identifier, parsed with ``ToolProviderID``
      :return: the decrypted OAuth parameters, or ``None`` when no usable client exists
      :raises ValueError: if the system OAuth parameters cannot be decrypted (the original
          exception is attached as ``__cause__``)
      """
      tool_provider = ToolProviderID(provider)
      provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
      encrypter, _ = create_provider_encrypter(
          tenant_id=tenant_id,
          config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
          cache=NoOpProviderCredentialCache(),
      )
      with Session(db.engine, autoflush=False) as session:
          user_client: ToolOAuthTenantClient | None = (
              session.query(ToolOAuthTenantClient)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=tool_provider.provider_name,
                  plugin_id=tool_provider.plugin_id,
                  enabled=True,
              )
              .first()
          )
          oauth_params: Mapping[str, Any] | None = None
          if user_client:
              oauth_params = encrypter.decrypt(user_client.oauth_params)
              return oauth_params

          # only verified provider can use official oauth client
          is_verified = not isinstance(
              provider_controller, PluginToolProviderController
          ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
          if not is_verified:
              return oauth_params

          system_client: ToolOAuthSystemClient | None = (
              session.query(ToolOAuthSystemClient)
              .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
              .first()
          )
          if system_client:
              try:
                  oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
              except Exception as e:
                  # Chain the cause so the underlying decryption failure stays visible.
                  raise ValueError(f"Error decrypting system oauth params: {e}") from e

          return oauth_params
  @staticmethod
  def get_builtin_tool_provider_icon(provider: str):
      """
      Load a provider's icon file together with its MIME type.

      :param provider: the name of the provider
      :return: ``(icon_bytes, mime_type)``
      """
      location, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
      return Path(location).read_bytes(), mime_type
  @staticmethod
  def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
      """
      List all builtin tool providers, with their tools, as API entities.

      Providers rejected by the include/exclude position settings are skipped. Each
      remaining provider is paired with the tenant's default stored credential record
      for it, if any.

      :param user_id: not used by this method
      :param tenant_id: the id of the tenant
      :return: the provider entities as ordered by ``BuiltinToolProviderSort.sort``
      """
      # get all builtin providers
      provider_controllers = ToolManager.list_builtin_providers(tenant_id)

      # get all user added providers
      db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)

      # Rewrite each stored provider name via ToolProviderID and index the records by it.
      # setdefault keeps the first record per name, like a first-match scan would.
      default_by_provider: dict[str, BuiltinToolProvider] = {}
      for db_provider in db_providers:
          db_provider.provider = str(ToolProviderID(db_provider.provider))
          default_by_provider.setdefault(db_provider.provider, db_provider)

      result: list[ToolProviderApiEntity] = []
      for provider_controller in provider_controllers:
          # handle include, exclude
          if is_filtered(
              include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
              exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
              data=provider_controller,
              name_func=lambda x: x.entity.identity.name,
          ):
              continue

          # convert provider controller to user provider
          user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
              provider_controller=provider_controller,
              db_provider=default_by_provider.get(provider_controller.entity.identity.name),
              decrypt_credentials=True,
          )

          # add icon
          ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)

          tools = provider_controller.get_tools()
          if tools:
              # Labels depend only on the provider, so resolve them once per provider.
              labels = ToolLabelManager.get_tool_labels(provider_controller)
              user_builtin_provider.tools.extend(
                  ToolTransformService.convert_tool_entity_to_api_entity(
                      tenant_id=tenant_id,
                      tool=tool,
                      labels=labels,
                  )
                  for tool in tools
              )

          result.append(user_builtin_provider)

      return BuiltinToolProviderSort.sort(result)
  @staticmethod
  def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
      """
      Fetch the tenant's stored record for a builtin provider from the database.

      1. if a default record exists, it is returned
      2. otherwise the oldest record is returned

      For the "langgenius" organization both the short provider name and the full
      name as passed in are matched; for any other organization only the full name is.

      :param provider_name: provider identifier, parsed with ``ToolProviderID``
      :param tenant_id: the id of the tenant
      :return: the matching record with ``provider`` rewritten through ``ToolProviderID``,
          or ``None`` when nothing matches; if anything in the main lookup raises, the
          result of a plain match on the provider name is returned instead
      """
      with Session(db.engine, autoflush=False) as session:

          def first_match(*provider_filters):
              # Shared query shape: tenant scope + given filters, default first, oldest first.
              return (
                  session.query(BuiltinToolProvider)
                  .where(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
                  .order_by(
                      BuiltinToolProvider.is_default.desc(),
                      BuiltinToolProvider.created_at.asc(),
                  )
                  .first()
              )

          requested_name = provider_name
          # Becomes the short name once parsing succeeds; the fallback uses whichever it holds.
          lookup_name = provider_name
          try:
              parsed_id = ToolProviderID(requested_name)
              lookup_name = parsed_id.provider_name

              if parsed_id.organization == "langgenius":
                  record = first_match(
                      (BuiltinToolProvider.provider == lookup_name)
                      | (BuiltinToolProvider.provider == requested_name)
                  )
              else:
                  record = first_match(BuiltinToolProvider.provider == requested_name)

              if record is None:
                  return None

              record.provider = ToolProviderID(record.provider).to_string()
              return record
          except Exception:
              # it's an old provider without organization
              return first_match(BuiltinToolProvider.provider == lookup_name)
  @staticmethod
  def save_custom_oauth_client_params(
      tenant_id: str,
      provider: str,
      client_params: Optional[dict] = None,
      enable_oauth_custom_client: Optional[bool] = None,
  ):
      """
      Create or update the tenant's custom OAuth client for a provider.

      :param tenant_id: the id of the tenant
      :param provider: provider identifier, parsed with ``ToolProviderID``
      :param client_params: new client parameters; entries equal to ``HIDDEN_VALUE`` keep
          the stored value. ``None`` leaves the stored parameters untouched
      :param enable_oauth_custom_client: new enabled flag; ``None`` leaves it untouched
      :return: ``{"result": "success"}``
      :raises ToolProviderNotFoundError: if no provider controller is found
      :raises ValueError: if the controller is neither a builtin nor a plugin provider
      """
      # Nothing was supplied, so there is nothing to persist.
      nothing_to_save = client_params is None and enable_oauth_custom_client is None
      if nothing_to_save:
          return {"result": "success"}

      provider_id = ToolProviderID(provider)
      controller = ToolManager.get_builtin_provider(provider, tenant_id)
      if not controller:
          raise ToolProviderNotFoundError(f"Provider {provider} not found")
      if not isinstance(controller, (BuiltinToolProviderController, PluginToolProviderController)):
          raise ValueError(f"Provider {provider} is not a builtin or plugin provider")

      with Session(db.engine) as session:
          record_identity = {
              "tenant_id": tenant_id,
              "plugin_id": provider_id.plugin_id,
              "provider": provider_id.provider_name,
          }
          tenant_client = session.query(ToolOAuthTenantClient).filter_by(**record_identity).first()

          # First save for this tenant/provider: start from a bare record.
          if tenant_client is None:
              tenant_client = ToolOAuthTenantClient(**record_identity)
              session.add(tenant_client)

          if client_params is not None:
              encrypter, _ = create_provider_encrypter(
                  tenant_id=tenant_id,
                  config=[field.to_basic_provider_config() for field in controller.get_oauth_client_schema()],
                  cache=NoOpProviderCredentialCache(),
              )
              stored_params = encrypter.decrypt(tenant_client.oauth_params)
              # Fields submitted as HIDDEN_VALUE were not edited: restore the stored value.
              merged_params: dict = {}
              for key, value in client_params.items():
                  if value != HIDDEN_VALUE:
                      merged_params[key] = value
                  else:
                      merged_params[key] = stored_params.get(key, UNKNOWN_VALUE)
              tenant_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(merged_params))

          if enable_oauth_custom_client is not None:
              tenant_client.enabled = enable_oauth_custom_client

          session.commit()
      return {"result": "success"}
  @staticmethod
  def get_custom_oauth_client_params(tenant_id: str, provider: str):
      """
      Return the tenant's custom OAuth client parameters for a provider, masked.

      :param tenant_id: the id of the tenant
      :param provider: provider identifier, parsed with ``ToolProviderID``
      :return: the decrypted parameters passed through ``mask_tool_credentials``, or an
          empty dict when the tenant has no custom client record
      :raises ToolProviderNotFoundError: if no provider controller is found
      :raises ValueError: if the controller is not a ``BuiltinToolProviderController``
      """
      with Session(db.engine) as session:
          provider_id = ToolProviderID(provider)
          lookup = session.query(ToolOAuthTenantClient).filter_by(
              tenant_id=tenant_id,
              plugin_id=provider_id.plugin_id,
              provider=provider_id.provider_name,
          )
          tenant_client: ToolOAuthTenantClient | None = lookup.first()
          if tenant_client is None:
              return {}

          controller = ToolManager.get_builtin_provider(provider, tenant_id)
          if not controller:
              raise ToolProviderNotFoundError(f"Provider {provider} not found")
          if not isinstance(controller, BuiltinToolProviderController):
              raise ValueError(f"Provider {provider} is not a builtin or plugin provider")

          schema_configs = [field.to_basic_provider_config() for field in controller.get_oauth_client_schema()]
          encrypter, _ = create_provider_encrypter(
              tenant_id=tenant_id,
              config=schema_configs,
              cache=NoOpProviderCredentialCache(),
          )
          decrypted_params = encrypter.decrypt(tenant_client.oauth_params)
          return encrypter.mask_tool_credentials(decrypted_params)