您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

datasource_provider_service.py 42KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974
  1. import logging
  2. import time
  3. from typing import Any
  4. from flask_login import current_user
  5. from sqlalchemy.orm import Session
  6. from configs import dify_config
  7. from constants import HIDDEN_VALUE, UNKNOWN_VALUE
  8. from core.helper import encrypter
  9. from core.helper.name_generator import generate_incremental_name
  10. from core.helper.provider_cache import NoOpProviderCredentialCache
  11. from core.model_runtime.entities.provider_entities import FormType
  12. from core.plugin.impl.datasource import PluginDatasourceManager
  13. from core.plugin.impl.oauth import OAuthHandler
  14. from core.tools.entities.tool_entities import CredentialType
  15. from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
  16. from extensions.ext_database import db
  17. from extensions.ext_redis import redis_client
  18. from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
  19. from models.provider_ids import DatasourceProviderID
  20. from services.plugin.plugin_service import PluginService
  21. logger = logging.getLogger(__name__)
  22. class DatasourceProviderService:
  23. """
  24. Datasource Provider Service
  25. """
  def __init__(self) -> None:
      """Create the service together with the plugin datasource manager it delegates to."""
      self.provider_manager = PluginDatasourceManager()
  def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
      """
      Delete the tenant's custom OAuth client configuration for a datasource provider.

      :param tenant_id: workspace id whose configuration rows are deleted
      :param datasource_provider_id: supplies the provider name and plugin id to match
      """
      with Session(db.engine) as session:
          tenant_configs = session.query(DatasourceOauthTenantParamConfig).filter_by(
              tenant_id=tenant_id,
              provider=datasource_provider_id.provider_name,
              plugin_id=datasource_provider_id.plugin_id,
          )
          tenant_configs.delete()
          session.commit()
  def decrypt_datasource_provider_credentials(
      self,
      tenant_id: str,
      datasource_provider: DatasourceProvider,
      plugin_id: str,
      provider: str,
  ) -> dict[str, Any]:
      """
      Return a copy of a credential's stored values with the secret fields decrypted.

      Secret fields are the ones reported by ``extract_secret_variables`` for the
      credential's auth type; every other field is copied through unchanged.

      :param tenant_id: workspace id, also handed to the token decrypter
      :param datasource_provider: stored credential row to read from
      :param plugin_id: plugin id, combined with ``provider`` into the provider id
      :param provider: provider name
      :return: new dict; the stored mapping itself is not modified
      """
      stored_credentials = datasource_provider.encrypted_credentials
      secret_keys = self.extract_secret_variables(
          tenant_id=tenant_id,
          provider_id=f"{plugin_id}/{provider}",
          credential_type=CredentialType.of(datasource_provider.auth_type),
      )
      plain_credentials = stored_credentials.copy()
      plain_credentials.update(
          {
              field: encrypter.decrypt_token(tenant_id, stored_value)
              for field, stored_value in stored_credentials.items()
              if field in secret_keys
          }
      )
      return plain_credentials
  def encrypt_datasource_provider_credentials(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
      raw_credentials: dict[str, Any],
      datasource_provider: DatasourceProvider,
  ) -> dict[str, Any]:
      """
      Return a copy of ``raw_credentials`` with the secret fields encrypted.

      Secret fields are the ones reported by ``extract_secret_variables`` for the
      credential's auth type; every other field is copied through unchanged.

      :param tenant_id: workspace id, also handed to the token encrypter
      :param provider: provider name
      :param plugin_id: plugin id, combined with ``provider`` into the provider id
      :param raw_credentials: plain credential values; not modified
      :param datasource_provider: stored credential row whose ``auth_type`` selects the schema
      :return: new dict ready to be stored as ``encrypted_credentials``
      """
      secret_keys = self.extract_secret_variables(
          tenant_id=tenant_id,
          provider_id=f"{plugin_id}/{provider}",
          # auth_type is the raw stored value; convert it with CredentialType.of exactly as the
          # decrypt path does, so both directions select the same credential schema.
          credential_type=CredentialType.of(datasource_provider.auth_type),
      )
      encrypted_credentials = raw_credentials.copy()
      for key, value in raw_credentials.items():
          if key in secret_keys:
              encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
      return encrypted_credentials
  def get_datasource_credentials(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
      credential_id: str | None = None,
  ) -> dict[str, Any]:
      """
      Load one stored credential and return its decrypted values.

      With ``credential_id`` the row is looked up by id within the tenant; otherwise the
      first row for ``plugin_id``/``provider`` is used, ordered by ``is_default`` descending
      and then ``created_at`` ascending. When ``expires_at`` is not ``-1`` and is less than
      60 seconds ahead of the current time, the credential is refreshed through the OAuth
      handler and the new values are committed before being returned.

      :param tenant_id: workspace id
      :param provider: provider name
      :param plugin_id: plugin id
      :param credential_id: optional id of a specific credential
      :return: decrypted credentials, or an empty dict when no credential matches
      """
      with Session(db.engine) as session:
          if credential_id:
              lookup = session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id)
          else:
              lookup = (
                  session.query(DatasourceProvider)
                  .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
                  .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
              )
          record = lookup.first()
          if not record:
              return {}

          # -1 marks a credential without an expiry; anything else is refreshed 60 seconds early.
          needs_refresh = record.expires_at != -1 and (record.expires_at - 60) < int(time.time())
          if needs_refresh:
              current_credentials = self.decrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  datasource_provider=record,
                  plugin_id=plugin_id,
                  provider=provider,
              )
              full_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
              oauth_provider_name = full_provider_id.provider_name
              callback_url = (
                  f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
                  f"{full_provider_id}/datasource/callback"
              )
              oauth_client = self.get_oauth_client(tenant_id, full_provider_id)
              refreshed = OAuthHandler().refresh_credentials(
                  tenant_id=tenant_id,
                  user_id=current_user.id,
                  plugin_id=full_provider_id.plugin_id,
                  provider=oauth_provider_name,
                  redirect_uri=callback_url,
                  system_credentials=oauth_client or {},
                  credentials=current_credentials,
              )
              record.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  raw_credentials=refreshed.credentials,
                  provider=provider,
                  plugin_id=plugin_id,
                  datasource_provider=record,
              )
              record.expires_at = refreshed.expires_at
              session.commit()

          return self.decrypt_datasource_provider_credentials(
              tenant_id=tenant_id,
              datasource_provider=record,
              plugin_id=plugin_id,
              provider=provider,
          )
  def get_all_datasource_credentials_by_provider(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
  ) -> list[dict[str, Any]]:
      """
      Return the decrypted credentials of every credential stored for a provider.

      Rows are ordered by ``is_default`` descending and then ``created_at`` ascending.
      A credential is refreshed through the OAuth handler only when it is due, using the
      same rule as ``get_datasource_credentials``: ``expires_at`` is not ``-1`` and is less
      than 60 seconds ahead of the current time. Credentials without an expiry are returned
      as stored instead of being pushed through an OAuth refresh.

      :param tenant_id: workspace id
      :param provider: provider name
      :param plugin_id: plugin id
      :return: one decrypted credential dict per stored credential; empty list when none exist
      """
      with Session(db.engine) as session:
          datasource_providers = (
              session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
              .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
              .all()
          )
          if not datasource_providers:
              return []

          # These values are the same for every row, so compute them once.
          datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
          provider_name = datasource_provider_id.provider_name
          redirect_uri = (
              f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
              f"{datasource_provider_id}/datasource/callback"
          )
          now = int(time.time())
          # The OAuth client is only looked up if some credential actually needs a refresh,
          # because get_oauth_client raises when no client configuration is available.
          system_credentials: dict[str, Any] | None = None
          oauth_client_loaded = False

          real_credentials_list = []
          for datasource_provider in datasource_providers:
              expires_at = datasource_provider.expires_at
              if expires_at != -1 and (expires_at - 60) < now:
                  if not oauth_client_loaded:
                      system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
                      oauth_client_loaded = True
                  decrypted_credentials = self.decrypt_datasource_provider_credentials(
                      tenant_id=tenant_id,
                      datasource_provider=datasource_provider,
                      plugin_id=plugin_id,
                      provider=provider,
                  )
                  refreshed_credentials = OAuthHandler().refresh_credentials(
                      tenant_id=tenant_id,
                      user_id=current_user.id,
                      plugin_id=datasource_provider_id.plugin_id,
                      provider=provider_name,
                      redirect_uri=redirect_uri,
                      system_credentials=system_credentials or {},
                      credentials=decrypted_credentials,
                  )
                  datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                      tenant_id=tenant_id,
                      raw_credentials=refreshed_credentials.credentials,
                      provider=provider,
                      plugin_id=plugin_id,
                      datasource_provider=datasource_provider,
                  )
                  datasource_provider.expires_at = refreshed_credentials.expires_at

              real_credentials_list.append(
                  self.decrypt_datasource_provider_credentials(
                      tenant_id=tenant_id,
                      datasource_provider=datasource_provider,
                      plugin_id=plugin_id,
                      provider=provider,
                  )
              )
          # Persist whatever was refreshed in a single commit.
          session.commit()
          return real_credentials_list
  def update_datasource_provider_name(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
  ):
      """
      Rename a stored credential.

      Nothing is written when the credential already has the requested name.

      :param tenant_id: workspace id
      :param datasource_provider_id: supplies the provider name and plugin id of the credential
      :param name: new display name
      :param credential_id: id of the credential to rename
      :raises ValueError: when the credential does not exist, or another credential of the
          same tenant/provider/plugin already uses ``name``
      """
      with Session(db.engine) as session:
          credential = (
              session.query(DatasourceProvider)
              .filter_by(
                  tenant_id=tenant_id,
                  id=credential_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if credential is None:
              raise ValueError("provider not found")
          if credential.name == name:
              return

          # Reject the rename when the name is already taken for this provider.
          same_name_count = (
              session.query(DatasourceProvider)
              .filter_by(
                  tenant_id=tenant_id,
                  name=name,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .count()
          )
          if same_name_count > 0:
              raise ValueError("Authorization name is already exists")

          credential.name = name
          session.commit()
  def set_default_datasource_provider(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
  ):
      """
      Make one credential the default for its provider.

      Every credential of the same tenant/provider/plugin currently flagged as default is
      cleared first, then the target is flagged and the change is committed.

      :param tenant_id: workspace id
      :param datasource_provider_id: supplies the provider name and plugin id of the credential
      :param credential_id: id of the credential to promote
      :return: ``{"result": "success"}``
      :raises ValueError: when the credential does not exist
      """
      with Session(db.engine) as session:
          new_default = (
              session.query(DatasourceProvider)
              .filter_by(
                  tenant_id=tenant_id,
                  id=credential_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if new_default is None:
              raise ValueError("provider not found")

          # Drop the flag from the current default(s) before setting it on the target.
          current_defaults = session.query(DatasourceProvider).filter_by(
              tenant_id=tenant_id,
              provider=new_default.provider,
              plugin_id=new_default.plugin_id,
              is_default=True,
          )
          current_defaults.update({"is_default": False})

          new_default.is_default = True
          session.commit()
          return {"result": "success"}
  def setup_oauth_custom_client_params(
      self,
      tenant_id: str,
      datasource_provider_id: DatasourceProviderID,
      client_params: dict | None,
      enabled: bool | None,
  ):
      """
      Create or update the tenant's custom OAuth client configuration for a provider.

      A configuration row is created (empty params, disabled) when none exists yet.
      Values in ``client_params`` equal to ``HIDDEN_VALUE`` keep the previously stored value
      for that key, or ``UNKNOWN_VALUE`` when there is none.

      :param tenant_id: workspace id
      :param datasource_provider_id: supplies the provider name and plugin id
      :param client_params: new client parameters, or ``None`` to leave them untouched
      :param enabled: new enabled flag, or ``None`` to leave it untouched
      :raises ValueError: from ``get_oauth_encrypter`` when the provider has no OAuth schema
      """
      if client_params is None and enabled is None:
          return
      with Session(db.engine) as session:
          tenant_oauth_client_params = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if not tenant_oauth_client_params:
              tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
                  client_params={},
                  enabled=False,
              )
              session.add(tenant_oauth_client_params)

          if client_params is not None:
              # Named oauth_encrypter so it does not shadow the imported `encrypter` module.
              oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
              # The row always exists at this point (loaded or just created above).
              original_params = oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
              new_params: dict = {
                  key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                  for key, value in client_params.items()
              }
              tenant_oauth_client_params.client_params = oauth_encrypter.encrypt(new_params)

          if enabled is not None:
              tenant_oauth_client_params.enabled = enabled
          session.commit()
  def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
      """
      Tell whether a system-level OAuth client configuration exists for a provider.

      :param datasource_provider_id: supplies the provider name and plugin id to match
      :return: ``True`` when a matching ``DatasourceOauthParamConfig`` row exists
      """
      # `Session(...).no_autoflush` used as the `with` target only toggles autoflush and never
      # closes the session. Open the session itself with autoflush disabled so it is closed
      # when the block exits.
      with Session(db.engine, autoflush=False) as session:
          return (
              session.query(DatasourceOauthParamConfig)
              .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
              .first()
              is not None
          )
  def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
      """
      Tell whether the tenant has an enabled custom OAuth client configuration for a provider.

      :param tenant_id: workspace id
      :param datasource_provider_id: supplies the provider name and plugin id to match
      :return: ``True`` when at least one matching row has ``enabled=True``
      """
      # `Session(...).no_autoflush` used as the `with` target only toggles autoflush and never
      # closes the session. Open the session itself with autoflush disabled so it is closed
      # when the block exits.
      with Session(db.engine, autoflush=False) as session:
          return (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
                  enabled=True,
              )
              .count()
              > 0
          )
  def get_tenant_oauth_client(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
  ) -> dict[str, Any] | None:
      """
      Return the tenant's custom OAuth client parameters for a provider.

      The row is returned regardless of its ``enabled`` flag.

      :param tenant_id: workspace id
      :param datasource_provider_id: supplies the provider name and plugin id to match
      :param mask: when true, the decrypted parameters are passed through the encrypter's
          ``mask_tool_credentials`` before being returned
      :return: decrypted (optionally masked) parameters, or ``None`` when no row exists
      :raises ValueError: from ``get_oauth_encrypter`` when the provider has no OAuth schema
      """
      # `Session(...).no_autoflush` used as the `with` target only toggles autoflush and never
      # closes the session. Open the session itself with autoflush disabled so it is closed
      # when the block exits.
      with Session(db.engine, autoflush=False) as session:
          tenant_oauth_client_params = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if not tenant_oauth_client_params:
              return None

          # Named oauth_encrypter so it does not shadow the imported `encrypter` module.
          oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
          decrypted_params = oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
          if mask:
              return oauth_encrypter.mask_tool_credentials(decrypted_params)
          return decrypted_params
  def get_oauth_encrypter(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID
  ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
      """
      Build the encrypter for a provider's OAuth client parameters.

      The encrypter is configured from the provider's OAuth ``client_schema`` and is given
      a ``NoOpProviderCredentialCache``.

      :param tenant_id: workspace id
      :param datasource_provider_id: provider to fetch from the plugin manager
      :return: the ``(encrypter, cache)`` pair produced by ``create_provider_encrypter``
      :raises ValueError: when the provider declares no OAuth schema
      """
      provider_entity = self.provider_manager.fetch_datasource_provider(
          tenant_id=tenant_id, provider_id=str(datasource_provider_id)
      )
      oauth_schema = provider_entity.declaration.oauth_schema
      if not oauth_schema:
          raise ValueError("Datasource provider oauth schema not found")

      client_configs = [field.to_basic_provider_config() for field in oauth_schema.client_schema]
      return create_provider_encrypter(
          tenant_id=tenant_id,
          config=client_configs,
          cache=NoOpProviderCredentialCache(),
      )
  def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
      """
      Resolve the OAuth client parameters to use for a provider.

      The tenant's own configuration wins when it exists and is enabled. Otherwise, and
      only when ``PluginService.is_plugin_verified`` reports the plugin as verified, the
      system-level configuration is used.

      :param tenant_id: workspace id
      :param datasource_provider_id: supplies the provider name and plugin id
      :return: decrypted tenant parameters, or the system configuration's ``system_credentials``
      :raises ValueError: when neither a tenant nor a usable system configuration is found
      """
      provider = datasource_provider_id.provider_name
      plugin_id = datasource_provider_id.plugin_id
      # `Session(...).no_autoflush` used as the `with` target only toggles autoflush and never
      # closes the session. Open the session itself with autoflush disabled so it is closed
      # when the block exits.
      with Session(db.engine, autoflush=False) as session:
          # Tenant-level client params take precedence when enabled.
          tenant_oauth_client_params = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=provider,
                  plugin_id=plugin_id,
                  enabled=True,
              )
              .first()
          )
          if tenant_oauth_client_params:
              # Named oauth_encrypter so it does not shadow the imported `encrypter` module.
              oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
              return oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)

          provider_controller = self.provider_manager.fetch_datasource_provider(
              tenant_id=tenant_id, provider_id=str(datasource_provider_id)
          )
          is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
          if is_verified:
              # Fall back to the system-level client params.
              oauth_client_params = (
                  session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
              )
              if oauth_client_params:
                  return oauth_client_params.system_credentials

          raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
  @staticmethod
  def generate_next_datasource_provider_name(
      session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
  ) -> str:
      """
      Propose a name for a new credential of the given type.

      The names of all credentials stored for the tenant/provider/plugin are passed to
      ``generate_incremental_name`` together with the credential type's display name.

      :param session: open session used for the lookup
      :param tenant_id: workspace id
      :param provider_id: supplies the provider name and plugin id
      :param credential_type: its ``get_name()`` is the base of the generated name
      :return: the name returned by ``generate_incremental_name``
      """
      stored_credentials = (
          session.query(DatasourceProvider)
          .filter_by(
              tenant_id=tenant_id,
              provider=provider_id.provider_name,
              plugin_id=provider_id.plugin_id,
          )
          .all()
      )
      taken_names = [row.name for row in stored_credentials]
      base_name = f"{credential_type.get_name()}"
      return generate_incremental_name(taken_names, base_name)
  def reauthorize_datasource_oauth_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      avatar_url: str | None,
      expire_at: int,
      credentials: dict,
      credential_id: str,
  ) -> None:
      """
      Replace the credentials of an existing OAuth credential after re-authorization.

      The update runs under the same Redis lock key that guards OAuth credential creation.
      When ``name`` is given it becomes the credential's name; if a different OAuth
      credential of the same provider already uses it, an incremental variant is generated
      instead. When ``name`` is empty the current name is kept. ``avatar_url`` replaces the
      stored one only when it is non-empty.

      :param name: optional new display name
      :param tenant_id: workspace id
      :param provider_id: supplies the provider name and plugin id
      :param avatar_url: optional new avatar url
      :param expire_at: new expiry value stored in ``expires_at``
      :param credentials: plain credential values; not modified
      :param credential_id: id of the credential to update
      :raises ValueError: when the credential does not exist
      """
      with Session(db.engine) as session:
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
          with redis_client.lock(lock, timeout=20):
              target_provider = (
                  session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
              )
              if target_provider is None:
                  raise ValueError("provider not found")

              db_provider_name = name or target_provider.name
              if name:
                  # Only other credentials can conflict: re-authorizing under the credential's
                  # own current name must not be treated as a clash.
                  name_conflict = (
                      session.query(DatasourceProvider)
                      .filter_by(
                          tenant_id=tenant_id,
                          name=name,
                          provider=provider_id.provider_name,
                          plugin_id=provider_id.plugin_id,
                          auth_type=CredentialType.OAUTH2.value,
                      )
                      .filter(DatasourceProvider.id != target_provider.id)
                      .count()
                  )
                  if name_conflict > 0:
                      db_provider_name = generate_incremental_name(
                          [
                              provider.name
                              for provider in session.query(DatasourceProvider).filter_by(
                                  tenant_id=tenant_id,
                                  provider=provider_id.provider_name,
                                  plugin_id=provider_id.plugin_id,
                              )
                          ],
                          name,
                      )

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
              )
              # Build a new dict so the caller's credentials are left untouched.
              encrypted_credentials = {
                  key: encrypter.encrypt_token(tenant_id, value)
                  if key in provider_credential_secret_variables
                  else value
                  for key, value in credentials.items()
              }

              # The resolved name was previously computed but never stored.
              target_provider.name = db_provider_name
              target_provider.expires_at = expire_at
              target_provider.encrypted_credentials = encrypted_credentials
              target_provider.avatar_url = avatar_url or target_provider.avatar_url
              session.commit()
  def add_datasource_oauth_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      avatar_url: str | None,
      expire_at: int,
      credentials: dict,
  ) -> None:
      """
      Store a new OAuth credential for a datasource provider.

      Creation runs under a Redis lock keyed by tenant, provider and credential type.
      Without ``name`` a name is generated from the credential type; with a ``name`` that
      an OAuth credential of the same provider already uses, an incremental variant is
      generated instead. Secret fields are encrypted before being stored.

      :param name: optional display name
      :param tenant_id: workspace id
      :param provider_id: supplies the provider name and plugin id
      :param avatar_url: avatar url; ``"default"`` is stored when empty
      :param expire_at: expiry value stored in ``expires_at``
      :param credentials: plain credential values; not modified
      """
      credential_type = CredentialType.OAUTH2
      with Session(db.engine) as session:
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
          with redis_client.lock(lock, timeout=60):
              db_provider_name = name
              if not db_provider_name:
                  db_provider_name = self.generate_next_datasource_provider_name(
                      session=session,
                      tenant_id=tenant_id,
                      provider_id=provider_id,
                      credential_type=credential_type,
                  )
              else:
                  name_conflict = (
                      session.query(DatasourceProvider)
                      .filter_by(
                          tenant_id=tenant_id,
                          name=db_provider_name,
                          provider=provider_id.provider_name,
                          plugin_id=provider_id.plugin_id,
                          auth_type=credential_type.value,
                      )
                      .count()
                  )
                  if name_conflict > 0:
                      db_provider_name = generate_incremental_name(
                          [
                              provider.name
                              for provider in session.query(DatasourceProvider).filter_by(
                                  tenant_id=tenant_id,
                                  provider=provider_id.provider_name,
                                  plugin_id=provider_id.plugin_id,
                              )
                          ],
                          db_provider_name,
                      )

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
              )
              # Build a new dict so the caller's credentials are left untouched.
              encrypted_credentials = {
                  key: encrypter.encrypt_token(tenant_id, value)
                  if key in provider_credential_secret_variables
                  else value
                  for key, value in credentials.items()
              }

              datasource_provider = DatasourceProvider(
                  tenant_id=tenant_id,
                  name=db_provider_name,
                  provider=provider_id.provider_name,
                  plugin_id=provider_id.plugin_id,
                  auth_type=credential_type.value,
                  encrypted_credentials=encrypted_credentials,
                  avatar_url=avatar_url or "default",
                  expires_at=expire_at,
              )
              session.add(datasource_provider)
              session.commit()
  def add_datasource_api_key_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      credentials: dict,
  ) -> None:
      """
      Validate API-key credentials and store them as a new credential.

      Creation runs under a Redis lock keyed by tenant, provider and credential type.
      Without ``name`` a name is generated from the credential type. The credentials are
      validated through the plugin manager before their secret fields are encrypted and
      the row is stored.

      :param name: optional display name
      :param tenant_id: workspace id
      :param provider_id: supplies the provider name and plugin id
      :param credentials: plain credential values; not modified
      :raises ValueError: when the name is already used for this provider, or when the
          plugin manager rejects the credentials (the original error is chained as the cause)
      """
      provider_name = provider_id.provider_name
      plugin_id = provider_id.plugin_id
      with Session(db.engine) as session:
          # `.value` keeps the key format in line with the OAuth create/reauthorize locks.
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY.value}"
          with redis_client.lock(lock, timeout=20):
              db_provider_name = name or self.generate_next_datasource_provider_name(
                  session=session,
                  tenant_id=tenant_id,
                  provider_id=provider_id,
                  credential_type=CredentialType.API_KEY,
              )

              # Reject a name that is already in use for this provider.
              name_in_use = (
                  session.query(DatasourceProvider)
                  .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
                  .count()
              )
              if name_in_use > 0:
                  raise ValueError("Authorization name is already exists")

              try:
                  self.provider_manager.validate_provider_credentials(
                      tenant_id=tenant_id,
                      user_id=current_user.id,
                      provider=provider_name,
                      plugin_id=plugin_id,
                      credentials=credentials,
                  )
              except Exception as e:
                  # Callers expect ValueError; chain the cause so the original traceback survives.
                  raise ValueError(f"Failed to validate credentials: {str(e)}") from e

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
              )
              # Build a new dict so the caller's credentials are left untouched.
              encrypted_credentials = {
                  key: encrypter.encrypt_token(tenant_id, value)
                  if key in provider_credential_secret_variables
                  else value
                  for key, value in credentials.items()
              }

              datasource_provider = DatasourceProvider(
                  tenant_id=tenant_id,
                  name=db_provider_name,
                  provider=provider_name,
                  plugin_id=plugin_id,
                  auth_type=CredentialType.API_KEY.value,
                  encrypted_credentials=encrypted_credentials,
              )
              session.add(datasource_provider)
              session.commit()
  def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
      """
      List the names of the secret-input fields in a provider's credential schema.

      The schema is the provider's ``credentials_schema`` for API-key credentials and the
      OAuth schema's ``credentials_schema`` for OAuth2 credentials.

      :param tenant_id: workspace id
      :param provider_id: provider id passed to the plugin manager
      :param credential_type: selects which schema is inspected
      :return: names of the fields whose form type is ``FormType.SECRET_INPUT``
      :raises ValueError: when an OAuth2 schema is requested but not declared, or the
          credential type is neither API key nor OAuth2
      """
      provider_entity = self.provider_manager.fetch_datasource_provider(
          tenant_id=tenant_id, provider_id=provider_id
      )
      if credential_type == CredentialType.API_KEY:
          form_schemas = list(provider_entity.declaration.credentials_schema)
      elif credential_type == CredentialType.OAUTH2:
          oauth_schema = provider_entity.declaration.oauth_schema
          if not oauth_schema:
              raise ValueError("Datasource provider oauth schema not found")
          form_schemas = list(oauth_schema.credentials_schema)
      else:
          raise ValueError(f"Invalid credential type: {credential_type}")

      return [
          form_schema.name
          for form_schema in form_schemas
          if form_schema.type.value == FormType.SECRET_INPUT.value
      ]
  def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
      """
      List a provider's stored credentials with their secret fields obfuscated.

      Secret fields are replaced by ``encrypter.obfuscated_token`` applied to the stored
      value; nothing is decrypted here. The credential reported as default is the first
      one when ordering by ``is_default`` descending and then ``created_at`` ascending.

      :param tenant_id: workspace id
      :param provider: provider name
      :param plugin_id: plugin id
      :return: one dict per credential with the keys ``credential``, ``type``, ``name``,
          ``avatar_url``, ``id`` and ``is_default``; empty list when none exist
      """
      # Get all provider configurations of the current workspace
      datasource_providers: list[DatasourceProvider] = (
          db.session.query(DatasourceProvider)
          .filter(
              DatasourceProvider.tenant_id == tenant_id,
              DatasourceProvider.provider == provider,
              DatasourceProvider.plugin_id == plugin_id,
          )
          .all()
      )
      if not datasource_providers:
          return []

      default_provider = (
          db.session.query(DatasourceProvider.id)
          .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
          .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
          .first()
      )
      default_provider_id = default_provider.id if default_provider else None

      # extract_secret_variables goes through the plugin manager on every call and its
      # arguments only vary by credential type here, so look each type up once.
      secret_variables_by_type: dict[CredentialType, list[str]] = {}
      copy_credentials_list = []
      for datasource_provider in datasource_providers:
          credential_type = CredentialType.of(datasource_provider.auth_type)
          if credential_type not in secret_variables_by_type:
              secret_variables_by_type[credential_type] = self.extract_secret_variables(
                  tenant_id=tenant_id,
                  provider_id=f"{plugin_id}/{provider}",
                  credential_type=credential_type,
              )
          credential_secret_variables = secret_variables_by_type[credential_type]

          # Obfuscate the secret fields; other fields are copied through as stored.
          copy_credentials = {
              key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
              for key, value in datasource_provider.encrypted_credentials.items()
          }
          copy_credentials_list.append(
              {
                  "credential": copy_credentials,
                  "type": datasource_provider.auth_type,
                  "name": datasource_provider.name,
                  "avatar_url": datasource_provider.avatar_url,
                  "id": datasource_provider.id,
                  "is_default": default_provider_id and datasource_provider.id == default_provider_id,
              }
          )
      return copy_credentials_list
  def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
      """
      Describe every installed datasource provider together with its stored credentials.

      Each entry carries the provider's identity fields, its obfuscated credential list,
      its credential schema and, when the provider declares an OAuth schema, the OAuth
      details (client/credential schemas, masked tenant client params, enabled flags and
      the callback redirect uri). ``oauth_schema`` is ``None`` otherwise.

      :param tenant_id: workspace id
      :return: one dict per installed datasource provider
      """
      installed_datasources = PluginDatasourceManager().fetch_installed_datasource_providers(tenant_id)
      provider_entries = []
      for datasource in installed_datasources:
          full_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
          stored_credentials = self.list_datasource_credentials(
              tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
          )
          callback_url = (
              f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{full_provider_id}/datasource/callback"
          )

          declaration = datasource.declaration
          identity = declaration.identity
          entry = {
              "provider": datasource.provider,
              "plugin_id": datasource.plugin_id,
              "plugin_unique_identifier": datasource.plugin_unique_identifier,
              "icon": identity.icon,
              "name": identity.name.split("/")[-1],
              "label": identity.label.model_dump(),
              "description": identity.description.model_dump(),
              "author": identity.author,
              "credentials_list": stored_credentials,
              "credential_schema": [field.model_dump() for field in declaration.credentials_schema],
          }

          # OAuth details are only reported for providers that declare an OAuth schema.
          oauth_schema = declaration.oauth_schema
          if oauth_schema:
              entry["oauth_schema"] = {
                  "client_schema": [field.model_dump() for field in oauth_schema.client_schema],
                  "credentials_schema": [field.model_dump() for field in oauth_schema.credentials_schema],
                  "oauth_custom_client_params": self.get_tenant_oauth_client(
                      tenant_id, full_provider_id, mask=True
                  ),
                  "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
                      tenant_id, full_provider_id
                  ),
                  "is_system_oauth_params_exists": self.is_system_oauth_params_exist(full_provider_id),
                  "redirect_uri": callback_url,
              }
          else:
              entry["oauth_schema"] = None

          provider_entries.append(entry)
      return provider_entries
  def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
      """
      Describe the installed datasource providers that belong to a fixed set of plugins.

      Only providers whose plugin id is one of the hard-coded Firecrawl, Notion or Jina
      datasource plugin ids are reported; every other installed provider is skipped.

      :param tenant_id: workspace id
      :return: one dict per matching provider holding its identity fields, the stored
          credentials list, the credential schema and the OAuth settings
          ("oauth_schema" is None when the declaration has no OAuth schema)
      """
      hard_coded_plugin_ids = (
          "langgenius/firecrawl_datasource",
          "langgenius/notion_datasource",
          "langgenius/jina_datasource",
      )
      installed_providers = PluginDatasourceManager().fetch_installed_datasource_providers(tenant_id)

      result = []
      for installed in installed_providers:
          if installed.plugin_id not in hard_coded_plugin_ids:
              continue

          provider_id = DatasourceProviderID(f"{installed.plugin_id}/{installed.provider}")
          stored_credentials = self.list_datasource_credentials(
              tenant_id=tenant_id, provider=installed.provider, plugin_id=installed.plugin_id
          )
          callback_url = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"

          declaration = installed.declaration
          identity = declaration.identity
          entry = {
              "provider": installed.provider,
              "plugin_id": installed.plugin_id,
              "plugin_unique_identifier": installed.plugin_unique_identifier,
              "icon": identity.icon,
              # keep only the last path segment of the declared name
              "name": identity.name.split("/")[-1],
              "label": identity.label.model_dump(),
              "description": identity.description.model_dump(),
              "author": identity.author,
              "credentials_list": stored_credentials,
              "credential_schema": [field.model_dump() for field in declaration.credentials_schema],
          }

          # The OAuth section is only populated for providers that declare an OAuth schema.
          oauth = declaration.oauth_schema
          if oauth:
              entry["oauth_schema"] = {
                  "client_schema": [field.model_dump() for field in oauth.client_schema],
                  "credentials_schema": [field.model_dump() for field in oauth.credentials_schema],
                  "oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, provider_id, mask=True),
                  "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(tenant_id, provider_id),
                  "is_system_oauth_params_exists": self.is_system_oauth_params_exist(provider_id),
                  "redirect_uri": callback_url,
              }
          else:
              entry["oauth_schema"] = None

          result.append(entry)
      return result
  def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
      """
      Return every stored credential set of a datasource provider with secrets decrypted.

      :param tenant_id: workspace id
      :param provider: provider name
      :param plugin_id: plugin id
      :return: a list of {"credentials": ..., "type": ...} dicts, one per stored record;
          empty when the workspace has no record for this provider
      """
      records: list[DatasourceProvider] = (
          db.session.query(DatasourceProvider)
          .filter(
              DatasourceProvider.tenant_id == tenant_id,
              DatasourceProvider.provider == provider,
              DatasourceProvider.plugin_id == plugin_id,
          )
          .all()
      )

      decrypted_list = []
      for record in records:
          stored = record.encrypted_credentials
          # Names of the credential fields that are stored encrypted for this auth type.
          secret_keys = self.extract_secret_variables(
              tenant_id=tenant_id,
              provider_id=f"{plugin_id}/{provider}",
              credential_type=CredentialType.of(record.auth_type),
          )
          # Decrypt the secret fields; all other fields are passed through untouched.
          plain_credentials = {
              key: encrypter.decrypt_token(tenant_id, value) if key in secret_keys else value
              for key, value in stored.items()
          }
          decrypted_list.append({"credentials": plain_credentials, "type": record.auth_type})
      return decrypted_list
  def update_datasource_credentials(
      self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
  ) -> None:
      """
      update the name and/or the credentials of a stored datasource authorization.

      :param tenant_id: workspace id
      :param auth_id: id of the DatasourceProvider record to update
      :param provider: provider name
      :param plugin_id: plugin id
      :param credentials: new credential values; a value equal to HIDDEN_VALUE is replaced by
          the stored (decrypted) value for that key, or UNKNOWN_VALUE when the key was not
          stored. Skipped when empty or None.
      :param name: new authorization name; skipped when empty, None or unchanged
      :raises ValueError: when the record does not exist, when another record of the same
          tenant/provider/plugin already uses the name, or when credential validation fails
      :return:
      """
      with Session(db.engine) as session:
          datasource_provider = (
              session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
              .first()
          )
          if not datasource_provider:
              raise ValueError("Datasource provider not found")

          # update name
          if name and name != datasource_provider.name:
              name_in_use = (
                  session.query(DatasourceProvider)
                  .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
                  .count()
                  > 0
              )
              if name_in_use:
                  raise ValueError("Authorization name is already exists")
              datasource_provider.name = name

          # update credentials
          if credentials:
              secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id,
                  provider_id=f"{plugin_id}/{provider}",
                  credential_type=CredentialType.of(datasource_provider.auth_type),
              )
              # Decrypt the stored secrets so that hidden placeholders can be resolved to real values.
              original_credentials = {
                  key: value if key not in secret_variables else encrypter.decrypt_token(tenant_id, value)
                  for key, value in datasource_provider.encrypted_credentials.items()
              }
              # A HIDDEN_VALUE entry means "keep what is stored" for that key.
              new_credentials = {
                  key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
                  for key, value in credentials.items()
              }
              try:
                  self.provider_manager.validate_provider_credentials(
                      tenant_id=tenant_id,
                      user_id=current_user.id,
                      provider=provider,
                      plugin_id=plugin_id,
                      credentials=new_credentials,
                  )
              except Exception as e:
                  # Chain the original error so the root cause stays visible in tracebacks.
                  raise ValueError(f"Failed to validate credentials: {e!s}") from e

              # Re-encrypt the secret fields before persisting; other fields are stored as given.
              datasource_provider.encrypted_credentials = {
                  key: encrypter.encrypt_token(tenant_id, value) if key in secret_variables else value
                  for key, value in new_credentials.items()
              }
          session.commit()
  def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
      """
      Delete one stored datasource authorization; does nothing when no record matches.

      :param tenant_id: workspace id
      :param auth_id: id of the DatasourceProvider record to delete
      :param provider: provider name
      :param plugin_id: plugin id
      :return:
      """
      lookup = db.session.query(DatasourceProvider).filter_by(
          tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id
      )
      record = lookup.first()
      if not record:
          # nothing stored under this id for the given tenant/provider/plugin
          return
      db.session.delete(record)
      db.session.commit()