You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_provider_service.py 42KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. import logging
  2. import time
  3. from collections.abc import Mapping
  4. from typing import Any
  5. from flask_login import current_user
  6. from sqlalchemy.orm import Session
  7. from configs import dify_config
  8. from constants import HIDDEN_VALUE, UNKNOWN_VALUE
  9. from core.helper import encrypter
  10. from core.helper.name_generator import generate_incremental_name
  11. from core.helper.provider_cache import NoOpProviderCredentialCache
  12. from core.model_runtime.entities.provider_entities import FormType
  13. from core.plugin.impl.datasource import PluginDatasourceManager
  14. from core.plugin.impl.oauth import OAuthHandler
  15. from core.tools.entities.tool_entities import CredentialType
  16. from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
  17. from extensions.ext_database import db
  18. from extensions.ext_redis import redis_client
  19. from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
  20. from models.provider_ids import DatasourceProviderID
  21. from services.plugin.plugin_service import PluginService
  # Module-level logger named after this module.
  logger = logging.getLogger(__name__)
  23. class DatasourceProviderService:
  24. """
  25. Model Provider Service
  26. """
  def __init__(self) -> None:
      """Create the service and the plugin datasource manager it uses to fetch provider declarations."""
      manager = PluginDatasourceManager()
      self.provider_manager = manager
  def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
      """
      Delete the tenant's custom OAuth client params for a datasource provider.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider whose tenant-level OAuth client config is removed
      """
      match_fields = {
          "tenant_id": tenant_id,
          "provider": datasource_provider_id.provider_name,
          "plugin_id": datasource_provider_id.plugin_id,
      }
      with Session(db.engine) as session:
          matching_configs = session.query(DatasourceOauthTenantParamConfig).filter_by(**match_fields)
          matching_configs.delete()
          session.commit()
  def decrypt_datasource_provider_credentials(
      self,
      tenant_id: str,
      datasource_provider: DatasourceProvider,
      plugin_id: str,
      provider: str,
  ) -> dict[str, Any]:
      """
      Return a copy of a stored record's credentials with its secret fields decrypted.

      Secret field names come from ``extract_secret_variables`` for the record's auth
      type; every other field is copied unchanged.

      :param tenant_id: workspace id passed to the decryption helper
      :param datasource_provider: the stored provider record
      :param plugin_id: plugin id that owns the provider
      :param provider: provider name
      :return: a new dict holding the decrypted credentials
      """
      stored_credentials = datasource_provider.encrypted_credentials
      secret_keys = self.extract_secret_variables(
          tenant_id=tenant_id,
          provider_id=f"{plugin_id}/{provider}",
          credential_type=CredentialType.of(datasource_provider.auth_type),
      )
      return {
          key: (encrypter.decrypt_token(tenant_id, value) if key in secret_keys else value)
          for key, value in stored_credentials.items()
      }
  def encrypt_datasource_provider_credentials(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
      raw_credentials: Mapping[str, Any],
      datasource_provider: DatasourceProvider,
  ) -> dict[str, Any]:
      """
      Return a copy of ``raw_credentials`` with its secret fields encrypted.

      Secret field names come from ``extract_secret_variables`` for the auth type of
      ``datasource_provider``; every other field is copied unchanged.

      :param tenant_id: workspace id passed to the encryption helper
      :param provider: provider name
      :param plugin_id: plugin id that owns the provider
      :param raw_credentials: plaintext credential values
      :param datasource_provider: record whose auth type selects the credential schema
      :return: a new dict ready to be stored as ``encrypted_credentials``
      """
      secret_keys = self.extract_secret_variables(
          tenant_id=tenant_id,
          provider_id=f"{plugin_id}/{provider}",
          credential_type=CredentialType.of(datasource_provider.auth_type),
      )
      plain_copy = dict(raw_credentials)
      return {
          key: (encrypter.encrypt_token(tenant_id, value) if key in secret_keys else value)
          for key, value in plain_copy.items()
      }
  def get_datasource_credentials(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
      credential_id: str | None = None,
  ) -> dict[str, Any]:
      """
      Get the decrypted credentials of a single datasource provider record.

      With ``credential_id`` the record is looked up by id within the tenant. Without it,
      the provider's records are ordered default first, then oldest first, and the first
      one is used. When the record's ``expires_at`` is not -1 and is less than 60 seconds
      away (or already past), the credentials are refreshed through ``OAuthHandler``.
      The refreshed values are re-encrypted and committed before being returned.

      :param tenant_id: workspace id
      :param provider: datasource provider name
      :param plugin_id: plugin id that owns the provider
      :param credential_id: optional id of a specific credential record
      :return: decrypted credentials, or an empty dict when no record is found
      """
      with Session(db.engine) as session:
          base_query = session.query(DatasourceProvider)
          if credential_id:
              record = base_query.filter_by(tenant_id=tenant_id, id=credential_id).first()
          else:
              record = (
                  base_query.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
                  .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
                  .first()
              )
          if not record:
              return {}

          expires_at = record.expires_at
          is_expiring = expires_at != -1 and (expires_at - 60) < int(time.time())
          if is_expiring:
              current_credentials = self.decrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  datasource_provider=record,
                  plugin_id=plugin_id,
                  provider=provider,
              )
              provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
              callback_uri = (
                  f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
                  f"{provider_id}/datasource/callback"
              )
              oauth_client = self.get_oauth_client(tenant_id, provider_id)
              refreshed = OAuthHandler().refresh_credentials(
                  tenant_id=tenant_id,
                  user_id=current_user.id,
                  plugin_id=provider_id.plugin_id,
                  provider=provider_id.provider_name,
                  redirect_uri=callback_uri,
                  system_credentials=oauth_client or {},
                  credentials=current_credentials,
              )
              record.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  raw_credentials=refreshed.credentials,
                  provider=provider,
                  plugin_id=plugin_id,
                  datasource_provider=record,
              )
              record.expires_at = refreshed.expires_at
              session.commit()

          return self.decrypt_datasource_provider_credentials(
              tenant_id=tenant_id,
              datasource_provider=record,
              plugin_id=plugin_id,
              provider=provider,
          )
  def get_all_datasource_credentials_by_provider(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
  ) -> list[dict[str, Any]]:
      """
      Get the decrypted credentials of every record stored for a datasource provider.

      Records are returned default first, then oldest first. A record is refreshed only
      when its ``expires_at`` is not -1 and is less than 60 seconds away (or already
      past); this is the same rule ``get_datasource_credentials`` applies. Refreshing goes
      through ``OAuthHandler`` and the refreshed values are re-encrypted and committed.
      Records that do not need a refresh are simply decrypted as stored.

      :param tenant_id: workspace id
      :param provider: datasource provider name
      :param plugin_id: plugin id that owns the provider
      :return: list of decrypted credential dicts; empty when no record exists
      """
      with Session(db.engine) as session:
          datasource_providers = (
              session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
              .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
              .all()
          )
          if not datasource_providers:
              return []

          # Loop invariants: every record here belongs to the same provider.
          datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
          redirect_uri = (
              f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
              f"{datasource_provider_id}/datasource/callback"
          )
          # Resolved lazily. get_oauth_client raises ValueError when no OAuth client is
          # configured, and that must not block returning credentials that need no refresh.
          system_credentials: dict[str, Any] | None = None
          system_credentials_loaded = False

          real_credentials_list = []
          for datasource_provider in datasource_providers:
              expires_at = datasource_provider.expires_at
              # Previously every record was refreshed unconditionally, including records
              # that never expire (expires_at == -1).
              if expires_at != -1 and (expires_at - 60) < int(time.time()):
                  if not system_credentials_loaded:
                      system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
                      system_credentials_loaded = True
                  decrypted_credentials = self.decrypt_datasource_provider_credentials(
                      tenant_id=tenant_id,
                      datasource_provider=datasource_provider,
                      plugin_id=plugin_id,
                      provider=provider,
                  )
                  refreshed_credentials = OAuthHandler().refresh_credentials(
                      tenant_id=tenant_id,
                      user_id=current_user.id,
                      plugin_id=datasource_provider_id.plugin_id,
                      provider=datasource_provider_id.provider_name,
                      redirect_uri=redirect_uri,
                      system_credentials=system_credentials or {},
                      credentials=decrypted_credentials,
                  )
                  datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                      tenant_id=tenant_id,
                      raw_credentials=refreshed_credentials.credentials,
                      provider=provider,
                      plugin_id=plugin_id,
                      datasource_provider=datasource_provider,
                  )
                  datasource_provider.expires_at = refreshed_credentials.expires_at
              real_credentials_list.append(
                  self.decrypt_datasource_provider_credentials(
                      tenant_id=tenant_id,
                      datasource_provider=datasource_provider,
                      plugin_id=plugin_id,
                      provider=provider,
                  )
              )
          session.commit()
          return real_credentials_list
  def update_datasource_provider_name(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
  ):
      """
      Rename one stored credential of a datasource provider.

      Nothing happens when ``name`` equals the current name.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider the credential belongs to
      :param name: new display name
      :param credential_id: id of the credential record to rename
      :raises ValueError: if the record is not found for this provider, or another record
          of the same provider already uses ``name``
      """
      owner_filter = {
          "tenant_id": tenant_id,
          "provider": datasource_provider_id.provider_name,
          "plugin_id": datasource_provider_id.plugin_id,
      }
      with Session(db.engine) as session:
          target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, **owner_filter).first()
          if target_provider is None:
              raise ValueError("provider not found")
          if target_provider.name == name:
              return

          # Names must be unique among one provider's credentials.
          same_name_count = session.query(DatasourceProvider).filter_by(name=name, **owner_filter).count()
          if same_name_count > 0:
              raise ValueError("Authorization name is already exists")

          target_provider.name = name
          session.commit()
  def set_default_datasource_provider(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
  ):
      """
      Mark one stored credential as the provider's default and clear any previous default.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider the credential belongs to
      :param credential_id: id of the credential record to make default
      :return: ``{"result": "success"}``
      :raises ValueError: if the record is not found for this provider
      """
      with Session(db.engine) as session:
          target_provider = (
              session.query(DatasourceProvider)
              .filter_by(
                  tenant_id=tenant_id,
                  id=credential_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if target_provider is None:
              raise ValueError("provider not found")

          current_defaults = session.query(DatasourceProvider).filter_by(
              tenant_id=tenant_id,
              provider=target_provider.provider,
              plugin_id=target_provider.plugin_id,
              is_default=True,
          )
          current_defaults.update({"is_default": False})
          target_provider.is_default = True
          session.commit()
      return {"result": "success"}
  def setup_oauth_custom_client_params(
      self,
      tenant_id: str,
      datasource_provider_id: DatasourceProviderID,
      client_params: dict | None,
      enabled: bool | None,
  ):
      """
      Create or update the tenant's custom OAuth client params for a datasource provider.

      Passing neither ``client_params`` nor ``enabled`` is a no-op. Otherwise, a record
      (empty params, disabled) is created if none exists.

      When ``client_params`` is given, a value equal to ``HIDDEN_VALUE`` keeps the
      previously stored value for that key, or ``UNKNOWN_VALUE`` if there was none.
      The merged params are stored encrypted. When ``enabled`` is given, it replaces
      the stored flag.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider the client params belong to
      :param client_params: new client params, or None to leave them unchanged
      :param enabled: new enabled flag, or None to leave it unchanged
      """
      if client_params is None and enabled is None:
          return
      with Session(db.engine) as session:
          tenant_oauth_client_params = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if not tenant_oauth_client_params:
              tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
                  client_params={},
                  enabled=False,
              )
              session.add(tenant_oauth_client_params)

          if client_params is not None:
              # Named so it does not shadow the module-level ``encrypter`` helper import.
              oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
              # A record always exists at this point, so its stored params can be decrypted directly.
              original_params = oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
              new_params: dict = {
                  key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                  for key, value in client_params.items()
              }
              tenant_oauth_client_params.client_params = oauth_encrypter.encrypt(new_params)

          if enabled is not None:
              tenant_oauth_client_params.enabled = enabled
          session.commit()
  def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
      """
      Check whether system-level OAuth client params exist for a datasource provider.

      :param datasource_provider_id: provider to look up
      :return: True if a ``DatasourceOauthParamConfig`` row matches the provider
      """
      # ``with Session(...).no_autoflush`` enters only the autoflush context and never
      # closes the Session. Enter the Session itself so it is closed on exit.
      with Session(db.engine) as session, session.no_autoflush:
          return (
              session.query(DatasourceOauthParamConfig)
              .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
              .first()
              is not None
          )
  def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
      """
      Check whether the tenant has an enabled custom OAuth client config for a provider.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider to look up
      :return: True if at least one matching record has ``enabled=True``
      """
      enabled_configs = db.session.query(DatasourceOauthTenantParamConfig).filter_by(
          tenant_id=tenant_id,
          provider=datasource_provider_id.provider_name,
          plugin_id=datasource_provider_id.plugin_id,
          enabled=True,
      )
      return enabled_configs.count() > 0
  def get_tenant_oauth_client(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
  ) -> dict[str, Any] | None:
      """
      Get the tenant's custom OAuth client params for a provider, decrypted.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider to look up
      :param mask: when True, pass the decrypted params through ``mask_tool_credentials``
      :return: the (optionally masked) params, or None when no record exists
      """
      config = (
          db.session.query(DatasourceOauthTenantParamConfig)
          .filter_by(
              tenant_id=tenant_id,
              provider=datasource_provider_id.provider_name,
              plugin_id=datasource_provider_id.plugin_id,
          )
          .first()
      )
      if not config:
          return None

      oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
      decrypted_params = oauth_encrypter.decrypt(config.client_params)
      return oauth_encrypter.mask_tool_credentials(decrypted_params) if mask else decrypted_params
  def get_oauth_encrypter(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID
  ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
      """
      Build an encrypter for the provider's OAuth client schema, backed by a no-op cache.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider whose OAuth client schema is used
      :return: the value returned by ``create_provider_encrypter``
      :raises ValueError: if the provider declares no OAuth schema
      """
      declaration = self.provider_manager.fetch_datasource_provider(
          tenant_id=tenant_id, provider_id=str(datasource_provider_id)
      ).declaration
      oauth_schema = declaration.oauth_schema
      if not oauth_schema:
          raise ValueError("Datasource provider oauth schema not found")

      basic_configs = [field.to_basic_provider_config() for field in oauth_schema.client_schema]
      return create_provider_encrypter(
          tenant_id=tenant_id,
          config=basic_configs,
          cache=NoOpProviderCredentialCache(),
      )
  def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
      """
      Resolve the OAuth client params to use for a datasource provider.

      Lookup order:
      1. the tenant's custom client params, when a matching record has ``enabled=True``
         (returned decrypted);
      2. the system client params, consulted only when the plugin is verified.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider to resolve the client for
      :return: tenant or system OAuth client params
      :raises ValueError: if neither tenant nor system params are available
      """
      provider = datasource_provider_id.provider_name
      plugin_id = datasource_provider_id.plugin_id
      # ``with Session(...).no_autoflush`` enters only the autoflush context and never
      # closes the Session. Enter the Session itself so it is closed on exit.
      with Session(db.engine) as session, session.no_autoflush:
          # get tenant oauth client params
          tenant_oauth_client_params = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=provider,
                  plugin_id=plugin_id,
                  enabled=True,
              )
              .first()
          )
          if tenant_oauth_client_params:
              # Named so it does not shadow the module-level ``encrypter`` helper import.
              oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
              return oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)

          provider_controller = self.provider_manager.fetch_datasource_provider(
              tenant_id=tenant_id, provider_id=str(datasource_provider_id)
          )
          is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
          if is_verified:
              # fall back to system oauth client params
              oauth_client_params = (
                  session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
              )
              if oauth_client_params:
                  return oauth_client_params.system_credentials
      raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
  @staticmethod
  def generate_next_datasource_provider_name(
      session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
  ) -> str:
      """
      Produce a name for a new credential of a provider.

      The result comes from ``generate_incremental_name``. It is given the names of all
      of this tenant's records for the provider, plus ``credential_type.get_name()`` as
      the base name.

      :param session: open ORM session used for the lookup
      :param tenant_id: workspace id
      :param provider_id: provider the new credential belongs to
      :param credential_type: credential type whose name is used as the base
      :return: the generated name
      """
      existing_records = session.query(DatasourceProvider).filter_by(
          tenant_id=tenant_id,
          provider=provider_id.provider_name,
          plugin_id=provider_id.plugin_id,
      )
      taken_names = [record.name for record in existing_records.all()]
      base_name = f"{credential_type.get_name()}"
      return generate_incremental_name(taken_names, base_name)
  def reauthorize_datasource_oauth_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      avatar_url: str | None,
      expire_at: int,
      credentials: dict,
      credential_id: str,
  ) -> None:
      """
      Replace the OAuth credentials stored on an existing credential record.

      Secret fields, as defined by the provider's OAuth credential schema, are encrypted
      before storage. The caller's ``credentials`` dict is left unmodified.
      ``avatar_url`` replaces the stored avatar only when it is truthy. The update runs
      under a Redis lock keyed by tenant, provider and credential type.

      :param name: optional display name (see NOTE below)
      :param tenant_id: workspace id
      :param provider_id: datasource provider id
      :param avatar_url: optional new avatar URL
      :param expire_at: new expiry value stored in ``expires_at``
      :param credentials: plaintext OAuth credentials
      :param credential_id: id of the record to update
      :raises ValueError: if no record with ``credential_id`` exists for the tenant
      """
      with Session(db.engine) as session:
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
          with redis_client.lock(lock, timeout=20):
              target_provider = (
                  session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
              )
              if target_provider is None:
                  raise ValueError("provider not found")

              # NOTE(review): db_provider_name is resolved, including conflict de-duplication,
              # but it is never written to target_provider, so ``name`` currently has no
              # effect. Confirm whether reauthorization is meant to rename the credential.
              db_provider_name = name
              if not db_provider_name:
                  db_provider_name = target_provider.name
              else:
                  name_conflict = (
                      session.query(DatasourceProvider)
                      .filter_by(
                          tenant_id=tenant_id,
                          name=db_provider_name,
                          provider=provider_id.provider_name,
                          plugin_id=provider_id.plugin_id,
                          auth_type=CredentialType.OAUTH2.value,
                      )
                      .count()
                  )
                  if name_conflict > 0:
                      db_provider_name = generate_incremental_name(
                          [
                              provider.name
                              for provider in session.query(DatasourceProvider).filter_by(
                                  tenant_id=tenant_id,
                                  provider=provider_id.provider_name,
                                  plugin_id=provider_id.plugin_id,
                              )
                          ],
                          db_provider_name,
                      )

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
              )
              # Build a new dict rather than encrypting the caller's dict in place.
              encrypted_credentials = {
                  key: (
                      encrypter.encrypt_token(tenant_id, value)
                      if key in provider_credential_secret_variables
                      else value
                  )
                  for key, value in credentials.items()
              }
              target_provider.expires_at = expire_at
              target_provider.encrypted_credentials = encrypted_credentials
              target_provider.avatar_url = avatar_url or target_provider.avatar_url
              session.commit()
  def add_datasource_oauth_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      avatar_url: str | None,
      expire_at: int,
      credentials: dict,
  ) -> None:
      """
      Store a new OAuth credential record for a datasource provider.

      With no ``name``, one is generated by ``generate_next_datasource_provider_name``.
      A given ``name`` that is already used by an OAuth record of this provider is
      replaced with a name from ``generate_incremental_name``.

      Secret fields, as defined by the provider's OAuth credential schema, are encrypted
      before storage. The caller's ``credentials`` dict is left unmodified. The insert
      runs under a Redis lock keyed by tenant, provider and credential type.

      :param name: optional display name
      :param tenant_id: workspace id
      :param provider_id: datasource provider id
      :param avatar_url: optional avatar URL; ``"default"`` is stored when empty
      :param expire_at: expiry value stored in ``expires_at``
      :param credentials: plaintext OAuth credentials
      """
      credential_type = CredentialType.OAUTH2
      with Session(db.engine) as session:
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
          with redis_client.lock(lock, timeout=60):
              db_provider_name = name
              if not db_provider_name:
                  db_provider_name = self.generate_next_datasource_provider_name(
                      session=session,
                      tenant_id=tenant_id,
                      provider_id=provider_id,
                      credential_type=credential_type,
                  )
              else:
                  if (
                      session.query(DatasourceProvider)
                      .filter_by(
                          tenant_id=tenant_id,
                          name=db_provider_name,
                          provider=provider_id.provider_name,
                          plugin_id=provider_id.plugin_id,
                          auth_type=credential_type.value,
                      )
                      .count()
                      > 0
                  ):
                      db_provider_name = generate_incremental_name(
                          [
                              provider.name
                              for provider in session.query(DatasourceProvider).filter_by(
                                  tenant_id=tenant_id,
                                  provider=provider_id.provider_name,
                                  plugin_id=provider_id.plugin_id,
                              )
                          ],
                          db_provider_name,
                      )

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
              )
              # Build a new dict rather than encrypting the caller's dict in place.
              encrypted_credentials = {
                  key: (
                      encrypter.encrypt_token(tenant_id, value)
                      if key in provider_credential_secret_variables
                      else value
                  )
                  for key, value in credentials.items()
              }
              datasource_provider = DatasourceProvider(
                  tenant_id=tenant_id,
                  name=db_provider_name,
                  provider=provider_id.provider_name,
                  plugin_id=provider_id.plugin_id,
                  auth_type=credential_type.value,
                  encrypted_credentials=encrypted_credentials,
                  avatar_url=avatar_url or "default",
                  expires_at=expire_at,
              )
              session.add(datasource_provider)
              session.commit()
  def add_datasource_api_key_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      credentials: dict,
  ) -> None:
      """
      Validate and store a new API-key credential record for a datasource provider.

      With no ``name``, one is generated by ``generate_next_datasource_provider_name``.
      The credentials are validated through the plugin manager. Secret fields, as
      defined by the provider's API-key credential schema, are then encrypted and
      stored. The caller's ``credentials`` dict is left unmodified.

      :param name: optional display name
      :param tenant_id: workspace id
      :param provider_id: datasource provider id
      :param credentials: plaintext credential values
      :raises ValueError: if the name is already used for this provider, or validation fails
      """
      provider_name = provider_id.provider_name
      plugin_id = provider_id.plugin_id
      with Session(db.engine) as session:
          # Use ``.value`` like the OAuth methods, so the key does not depend on enum formatting.
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY.value}"
          with redis_client.lock(lock, timeout=20):
              db_provider_name = name or self.generate_next_datasource_provider_name(
                  session=session,
                  tenant_id=tenant_id,
                  provider_id=provider_id,
                  credential_type=CredentialType.API_KEY,
              )

              # check name is exist
              if (
                  session.query(DatasourceProvider)
                  .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
                  .count()
                  > 0
              ):
                  raise ValueError("Authorization name is already exists")

              try:
                  self.provider_manager.validate_provider_credentials(
                      tenant_id=tenant_id,
                      user_id=current_user.id,
                      provider=provider_name,
                      plugin_id=plugin_id,
                      credentials=credentials,
                  )
              except Exception as e:
                  # Chain the original error so its traceback is preserved.
                  raise ValueError(f"Failed to validate credentials: {str(e)}") from e

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
              )
              # Encrypt secret fields into a new dict; the caller's dict is left as passed.
              encrypted_credentials = {
                  key: (
                      encrypter.encrypt_token(tenant_id, value)
                      if key in provider_credential_secret_variables
                      else value
                  )
                  for key, value in credentials.items()
              }
              datasource_provider = DatasourceProvider(
                  tenant_id=tenant_id,
                  name=db_provider_name,
                  provider=provider_name,
                  plugin_id=plugin_id,
                  auth_type=CredentialType.API_KEY.value,
                  encrypted_credentials=encrypted_credentials,
              )
              session.add(datasource_provider)
              session.commit()
  def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
      """
      List the names of secret-input fields in a provider's credential schema.

      The API-key schema is used for ``CredentialType.API_KEY``, and the OAuth credential
      schema for ``CredentialType.OAUTH2``.

      :param tenant_id: workspace id
      :param provider_id: provider id in ``plugin_id/provider`` form
      :param credential_type: which credential schema to inspect
      :return: names of fields whose form type is ``FormType.SECRET_INPUT``
      :raises ValueError: for an OAuth type when the provider has no OAuth schema, or for
          any other credential type
      """
      provider_controller = self.provider_manager.fetch_datasource_provider(
          tenant_id=tenant_id, provider_id=provider_id
      )
      declaration = provider_controller.declaration
      if credential_type == CredentialType.API_KEY:
          form_schemas = list(declaration.credentials_schema)
      elif credential_type == CredentialType.OAUTH2:
          oauth_schema = declaration.oauth_schema
          if not oauth_schema:
              raise ValueError("Datasource provider oauth schema not found")
          form_schemas = list(oauth_schema.credentials_schema)
      else:
          raise ValueError(f"Invalid credential type: {credential_type}")

      secret_form_type = FormType.SECRET_INPUT.value
      return [schema.name for schema in form_schemas if schema.type.value == secret_form_type]
  def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
      """
      List a provider's stored credentials with secret fields obfuscated.

      :param tenant_id: workspace id
      :param provider: datasource provider name
      :param plugin_id: plugin id that owns the provider
      :return: one dict per record, with keys ``credential`` (copy with secret values
          obfuscated), ``type``, ``name``, ``avatar_url``, ``id`` and ``is_default``.
          ``is_default`` is True only for the record that sorts first by
          (is_default desc, created_at asc). Returns an empty list when no record exists.
      """
      # Get all provider configurations of the current workspace
      datasource_providers: list[DatasourceProvider] = (
          db.session.query(DatasourceProvider)
          .where(
              DatasourceProvider.tenant_id == tenant_id,
              DatasourceProvider.provider == provider,
              DatasourceProvider.plugin_id == plugin_id,
          )
          .all()
      )
      if not datasource_providers:
          return []

      default_provider = (
          db.session.query(DatasourceProvider.id)
          .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
          .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
          .first()
      )
      default_provider_id = default_provider.id if default_provider else None

      # Secret field names depend only on the auth type. Caching them means the provider
      # declaration is fetched once per auth type instead of once per record.
      secret_variables_by_auth_type: dict[str, list[str]] = {}
      copy_credentials_list = []
      for datasource_provider in datasource_providers:
          auth_type = datasource_provider.auth_type
          if auth_type not in secret_variables_by_auth_type:
              secret_variables_by_auth_type[auth_type] = self.extract_secret_variables(
                  tenant_id=tenant_id,
                  provider_id=f"{plugin_id}/{provider}",
                  credential_type=CredentialType.of(auth_type),
              )
          credential_secret_variables = secret_variables_by_auth_type[auth_type]

          # Obfuscate provider credentials
          copy_credentials = {
              key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
              for key, value in datasource_provider.encrypted_credentials.items()
          }
          copy_credentials_list.append(
              {
                  "credential": copy_credentials,
                  "type": datasource_provider.auth_type,
                  "name": datasource_provider.name,
                  "avatar_url": datasource_provider.avatar_url,
                  "id": datasource_provider.id,
                  # Always a bool; the previous ``id and ...`` form could yield a non-bool.
                  "is_default": default_provider_id is not None and datasource_provider.id == default_provider_id,
              }
          )
      return copy_credentials_list
  def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
      """
      Describe every installed datasource provider together with its stored credentials.

      Each entry holds the provider's identity fields, the obfuscated credentials from
      ``list_datasource_credentials`` and the credential schema. When the provider
      declares an OAuth schema, ``oauth_schema`` additionally holds:
      - the client and credential schemas;
      - the tenant's masked custom client params;
      - whether those params are enabled;
      - whether system params exist;
      - the OAuth callback URL.
      Otherwise ``oauth_schema`` is None.

      :param tenant_id: workspace id
      :return: list of provider description dicts
      """
      # Use the service's own manager rather than constructing a second client.
      datasources = self.provider_manager.fetch_installed_datasource_providers(tenant_id)
      datasource_credentials = []
      for datasource in datasources:
          datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
          declaration = datasource.declaration
          identity = declaration.identity
          credentials = self.list_datasource_credentials(
              tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
          )
          redirect_uri = (
              f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
          )

          oauth_schema = None
          if declaration.oauth_schema:
              oauth_schema = {
                  "client_schema": [
                      client_schema.model_dump() for client_schema in declaration.oauth_schema.client_schema
                  ],
                  "credentials_schema": [
                      credential_schema.model_dump()
                      for credential_schema in declaration.oauth_schema.credentials_schema
                  ],
                  "oauth_custom_client_params": self.get_tenant_oauth_client(
                      tenant_id, datasource_provider_id, mask=True
                  ),
                  "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
                      tenant_id, datasource_provider_id
                  ),
                  "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
                  "redirect_uri": redirect_uri,
              }

          datasource_credentials.append(
              {
                  "provider": datasource.provider,
                  "plugin_id": datasource.plugin_id,
                  "plugin_unique_identifier": datasource.plugin_unique_identifier,
                  "icon": identity.icon,
                  "name": identity.name.split("/")[-1],
                  "label": identity.label.model_dump(),
                  "description": identity.description.model_dump(),
                  "author": identity.author,
                  "credentials_list": credentials,
                  "credential_schema": [credential.model_dump() for credential in declaration.credentials_schema],
                  "oauth_schema": oauth_schema,
              }
          )
      return datasource_credentials
  def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
      """
      Get credential listings for a fixed set of datasource plugins installed in a workspace.

      Only installed providers whose plugin id is one of the hard-coded ids below
      are included; every other installed datasource plugin is skipped.

      :param tenant_id: workspace id
      :return: one dict per matching installed datasource provider, containing its
          identity metadata, the credentials returned by ``list_datasource_credentials``,
          its credential schema, and its OAuth configuration (``None`` when the provider
          declares no OAuth schema).
      """
      # Built once (not per iteration); a set gives constant-time membership tests.
      hard_code_plugin_ids = frozenset(
          {
              "langgenius/firecrawl_datasource",
              "langgenius/notion_datasource",
              "langgenius/jina_datasource",
          }
      )
      manager = PluginDatasourceManager()
      datasources = manager.fetch_installed_datasource_providers(tenant_id)
      datasource_credentials = []
      for datasource in datasources:
          if datasource.plugin_id not in hard_code_plugin_ids:
              continue
          datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
          credentials = self.list_datasource_credentials(
              tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
          )
          redirect_uri = (
              f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
              f"{datasource_provider_id}/datasource/callback"
          )
          declaration = datasource.declaration
          identity = declaration.identity
          datasource_credentials.append(
              {
                  "provider": datasource.provider,
                  "plugin_id": datasource.plugin_id,
                  "plugin_unique_identifier": datasource.plugin_unique_identifier,
                  "icon": identity.icon,
                  # Keep only the last path segment of the declared name.
                  "name": identity.name.split("/")[-1],
                  "label": identity.label.model_dump(),
                  "description": identity.description.model_dump(),
                  "author": identity.author,
                  "credentials_list": credentials,
                  "credential_schema": [credential.model_dump() for credential in declaration.credentials_schema],
                  # The OAuth lookups below only run when the provider declares an OAuth schema.
                  "oauth_schema": {
                      "client_schema": [
                          client_schema.model_dump() for client_schema in declaration.oauth_schema.client_schema
                      ],
                      "credentials_schema": [
                          credential_schema.model_dump()
                          for credential_schema in declaration.oauth_schema.credentials_schema
                      ],
                      "oauth_custom_client_params": self.get_tenant_oauth_client(
                          tenant_id, datasource_provider_id, mask=True
                      ),
                      "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
                          tenant_id, datasource_provider_id
                      ),
                      "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
                      "redirect_uri": redirect_uri,
                  }
                  if declaration.oauth_schema
                  else None,
              }
          )
      return datasource_credentials
  def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
      """
      Get the stored credentials of a datasource provider with secret fields decrypted.

      Secret values are returned in plain text (decrypted, not masked), so the
      result must not be sent to clients as-is.

      :param tenant_id: workspace id
      :param provider: provider name
      :param plugin_id: plugin id
      :return: one ``{"credentials": dict, "type": auth_type}`` entry per stored
          credential record of this provider in the workspace; an empty list when
          there are none.
      """
      provider_id = f"{plugin_id}/{provider}"
      # All credential records of this provider in the current workspace.
      datasource_providers: list[DatasourceProvider] = (
          db.session.query(DatasourceProvider)
          .where(
              DatasourceProvider.tenant_id == tenant_id,
              DatasourceProvider.provider == provider,
              DatasourceProvider.plugin_id == plugin_id,
          )
          .all()
      )
      # The only argument of extract_secret_variables that varies between records is
      # the credential type, so resolve the secret variable names once per auth type.
      secret_variables_by_auth_type = {}
      real_credentials_list = []
      for datasource_provider in datasource_providers:
          auth_type = datasource_provider.auth_type
          if auth_type not in secret_variables_by_auth_type:
              secret_variables_by_auth_type[auth_type] = self.extract_secret_variables(
                  tenant_id=tenant_id,
                  provider_id=provider_id,
                  credential_type=CredentialType.of(auth_type),
              )
          secret_variables = secret_variables_by_auth_type[auth_type]
          # Decrypt secret fields; build a new dict so the ORM object is not mutated.
          real_credentials = {
              key: encrypter.decrypt_token(tenant_id, value) if key in secret_variables else value
              for key, value in datasource_provider.encrypted_credentials.items()
          }
          real_credentials_list.append(
              {
                  "credentials": real_credentials,
                  "type": auth_type,
              }
          )
      return real_credentials_list
  def update_datasource_credentials(
      self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
  ) -> None:
      """
      Update the name and/or the credentials of one stored datasource authorization.

      :param tenant_id: workspace id
      :param auth_id: id of the DatasourceProvider record to update
      :param provider: provider name
      :param plugin_id: plugin id
      :param credentials: new credential values; falsy keeps the stored credentials.
          A value equal to HIDDEN_VALUE is replaced by the currently stored (decrypted)
          value for that key, or UNKNOWN_VALUE when no such key is stored.
      :param name: new authorization name; falsy keeps the current name
      :raises ValueError: if the record does not exist, if the new name is already used
          by an authorization of the same provider in this workspace, or if the merged
          credentials fail provider validation
      """
      with Session(db.engine) as session:
          datasource_provider = (
              session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
              .first()
          )
          if not datasource_provider:
              raise ValueError("Datasource provider not found")

          # Rename, rejecting names already used by this provider in the workspace.
          if name and name != datasource_provider.name:
              name_taken = (
                  session.query(DatasourceProvider)
                  .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
                  .count()
                  > 0
              )
              if name_taken:
                  raise ValueError("Authorization name is already exists")
              datasource_provider.name = name

          if credentials:
              secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id,
                  provider_id=f"{plugin_id}/{provider}",
                  credential_type=CredentialType.of(datasource_provider.auth_type),
              )
              # Plain-text view of what is stored now, used to fill in masked fields.
              original_credentials = {
                  key: encrypter.decrypt_token(tenant_id, value) if key in secret_variables else value
                  for key, value in datasource_provider.encrypted_credentials.items()
              }
              # Masked (HIDDEN_VALUE) fields mean "unchanged": restore the stored value.
              new_credentials = {
                  key: original_credentials.get(key, UNKNOWN_VALUE) if value == HIDDEN_VALUE else value
                  for key, value in credentials.items()
              }
              try:
                  self.provider_manager.validate_provider_credentials(
                      tenant_id=tenant_id,
                      user_id=current_user.id,
                      provider=provider,
                      plugin_id=plugin_id,
                      credentials=new_credentials,
                  )
              except Exception as e:
                  # Nothing has been committed yet, so a pending rename is not persisted.
                  # Chain the cause so the underlying traceback is preserved.
                  raise ValueError(f"Failed to validate credentials: {e}") from e
              datasource_provider.encrypted_credentials = {
                  key: encrypter.encrypt_token(tenant_id, value) if key in secret_variables else value
                  for key, value in new_credentials.items()
              }
          session.commit()
  def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
      """
      Delete one stored datasource authorization of a workspace.

      When no record matches, the call does nothing (no error is raised and
      nothing is committed).

      :param tenant_id: workspace id
      :param auth_id: id of the DatasourceProvider record to delete
      :param provider: provider name
      :param plugin_id: plugin id
      :return: None
      """
      matching = db.session.query(DatasourceProvider).filter_by(
          tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id
      )
      target = matching.first()
      if not target:
          return
      db.session.delete(target)
      db.session.commit()