選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

datasource_provider_service.py 42KB

5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
3ヶ月前
3ヶ月前
3ヶ月前
3ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
4ヶ月前
5ヶ月前
4ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
4ヶ月前
4ヶ月前
5ヶ月前
3ヶ月前
5ヶ月前
4ヶ月前
3ヶ月前
5ヶ月前
4ヶ月前
4ヶ月前
4ヶ月前
3ヶ月前
3ヶ月前
4ヶ月前
3ヶ月前
3ヶ月前
3ヶ月前
3ヶ月前
3ヶ月前
3ヶ月前
4ヶ月前
4ヶ月前
3ヶ月前
5ヶ月前
5ヶ月前
4ヶ月前
4ヶ月前
5ヶ月前
4ヶ月前
4ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
3ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974
  1. import logging
  2. import time
  3. from typing import Any
  4. from flask_login import current_user
  5. from sqlalchemy.orm import Session
  6. from configs import dify_config
  7. from constants import HIDDEN_VALUE, UNKNOWN_VALUE
  8. from core.helper import encrypter
  9. from core.helper.name_generator import generate_incremental_name
  10. from core.helper.provider_cache import NoOpProviderCredentialCache
  11. from core.model_runtime.entities.provider_entities import FormType
  12. from core.plugin.impl.datasource import PluginDatasourceManager
  13. from core.plugin.impl.oauth import OAuthHandler
  14. from core.tools.entities.tool_entities import CredentialType
  15. from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
  16. from extensions.ext_database import db
  17. from extensions.ext_redis import redis_client
  18. from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
  19. from models.provider_ids import DatasourceProviderID
  20. from services.plugin.plugin_service import PluginService
  21. logger = logging.getLogger(__name__)
  22. class DatasourceProviderService:
  23. """
  24. Model Provider Service
  25. """
  def __init__(self) -> None:
      # Client used by this service to fetch datasource provider declarations and to validate
      # datasource credentials through the plugin.
      manager = PluginDatasourceManager()
      self.provider_manager = manager
  def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
      """
      Delete the workspace's custom OAuth client params for a datasource provider.

      :param tenant_id: workspace id
      :param datasource_provider_id: plugin/provider whose custom client params are removed
      """
      filters = {
          "tenant_id": tenant_id,
          "provider": datasource_provider_id.provider_name,
          "plugin_id": datasource_provider_id.plugin_id,
      }
      with Session(db.engine) as session:
          session.query(DatasourceOauthTenantParamConfig).filter_by(**filters).delete()
          session.commit()
  def decrypt_datasource_provider_credentials(
      self,
      tenant_id: str,
      datasource_provider: DatasourceProvider,
      plugin_id: str,
      provider: str,
  ) -> dict[str, Any]:
      """
      Return a copy of a record's stored credentials with every secret field decrypted.

      Secret fields are the secret-input fields of the form matching the record's auth type.
      Non-secret fields are copied unchanged; the record's own dict is not modified.

      :param tenant_id: workspace id used for decryption
      :param datasource_provider: credential record to decrypt
      :param plugin_id: plugin id of the provider
      :param provider: provider name
      :return: credentials with secret values decrypted
      """
      stored = datasource_provider.encrypted_credentials
      secret_names = self.extract_secret_variables(
          tenant_id=tenant_id,
          provider_id=f"{plugin_id}/{provider}",
          credential_type=CredentialType.of(datasource_provider.auth_type),
      )
      result = stored.copy()
      for field_name in stored:
          if field_name in secret_names:
              result[field_name] = encrypter.decrypt_token(tenant_id, stored[field_name])
      return result
  def encrypt_datasource_provider_credentials(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
      raw_credentials: dict[str, Any],
      datasource_provider: DatasourceProvider,
  ) -> dict[str, Any]:
      """
      Return a copy of ``raw_credentials`` with every secret field encrypted.

      Secret fields are the secret-input fields of the form matching the record's auth type.
      ``raw_credentials`` itself is not modified.

      :param tenant_id: workspace id used for encryption
      :param provider: provider name
      :param plugin_id: plugin id of the provider
      :param raw_credentials: plain-text credentials
      :param datasource_provider: record whose ``auth_type`` selects the credential form
      :return: credentials with secret values encrypted
      """
      # Normalize the stored auth type exactly like decrypt_datasource_provider_credentials does, so
      # encryption and decryption always resolve the same list of secret fields.
      provider_credential_secret_variables = self.extract_secret_variables(
          tenant_id=tenant_id,
          provider_id=f"{plugin_id}/{provider}",
          credential_type=CredentialType.of(datasource_provider.auth_type),
      )
      encrypted_credentials = raw_credentials.copy()
      for key, value in raw_credentials.items():
          if key in provider_credential_secret_variables:
              encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
      return encrypted_credentials
  def get_datasource_credentials(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
      credential_id: str | None = None,
  ) -> dict[str, Any]:
      """
      Return the decrypted credentials of one credential record, refreshing them if about to expire.

      With ``credential_id`` the record with that id in the workspace is used. Otherwise the
      workspace's record for ``plugin_id/provider`` flagged as default is used, falling back to
      the oldest one. When the record has an expiry (``expires_at != -1``) less than 60 seconds
      away or already passed, the credentials are refreshed through the plugin OAuth handler. The
      re-encrypted credentials and the new expiry are then committed.

      :return: decrypted credentials, or an empty dict when no record matches
      """
      with Session(db.engine) as session:
          base_query = session.query(DatasourceProvider)
          if credential_id:
              datasource_provider = base_query.filter_by(tenant_id=tenant_id, id=credential_id).first()
          else:
              datasource_provider = (
                  base_query.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
                  .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
                  .first()
              )
          if not datasource_provider:
              return {}

          expiry = datasource_provider.expires_at
          if expiry != -1 and (expiry - 60) < int(time.time()):
              current_credentials = self.decrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  datasource_provider=datasource_provider,
                  plugin_id=plugin_id,
                  provider=provider,
              )
              full_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
              callback_url = (
                  f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
                  f"{full_provider_id}/datasource/callback"
              )
              client_params = self.get_oauth_client(tenant_id, full_provider_id)
              refreshed = OAuthHandler().refresh_credentials(
                  tenant_id=tenant_id,
                  user_id=current_user.id,
                  plugin_id=full_provider_id.plugin_id,
                  provider=full_provider_id.provider_name,
                  redirect_uri=callback_url,
                  system_credentials=client_params or {},
                  credentials=current_credentials,
              )
              datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  raw_credentials=refreshed.credentials,
                  provider=provider,
                  plugin_id=plugin_id,
                  datasource_provider=datasource_provider,
              )
              datasource_provider.expires_at = refreshed.expires_at
              session.commit()

          return self.decrypt_datasource_provider_credentials(
              tenant_id=tenant_id,
              datasource_provider=datasource_provider,
              plugin_id=plugin_id,
              provider=provider,
          )
  def get_all_datasource_credentials_by_provider(
      self,
      tenant_id: str,
      provider: str,
      plugin_id: str,
  ) -> list[dict[str, Any]]:
      """
      Return the decrypted credentials of every credential record for ``plugin_id/provider``.

      Records are ordered default-first, then by creation time. A record is refreshed through the
      plugin OAuth handler only when it has an expiry (``expires_at != -1``) less than 60 seconds
      away or already passed. This is the same rule get_datasource_credentials uses. Refreshed
      credentials and expiries are committed before returning.

      :param tenant_id: workspace id
      :param provider: provider name
      :param plugin_id: plugin id of the provider
      :return: decrypted credential dicts (empty when the provider has no records)
      """
      with Session(db.engine) as session:
          datasource_providers = (
              session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
              .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
              .all()
          )
          if not datasource_providers:
              return []

          # Identical for every record, so computed once instead of per iteration.
          datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
          redirect_uri = (
              f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
              f"{datasource_provider_id}/datasource/callback"
          )
          # Loaded lazily: get_oauth_client raises when no OAuth client is configured. That must not
          # break listing records that never need a refresh (e.g. API-key credentials).
          system_credentials: dict[str, Any] | None = None
          system_credentials_loaded = False
          now = int(time.time())

          real_credentials_list: list[dict[str, Any]] = []
          for datasource_provider in datasource_providers:
              decrypted_credentials = self.decrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  datasource_provider=datasource_provider,
                  plugin_id=plugin_id,
                  provider=provider,
              )
              expires_at = datasource_provider.expires_at
              if expires_at == -1 or expires_at - 60 >= now:
                  # Not expiring: return the stored credentials as-is.
                  real_credentials_list.append(decrypted_credentials)
                  continue

              if not system_credentials_loaded:
                  system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
                  system_credentials_loaded = True
              refreshed_credentials = OAuthHandler().refresh_credentials(
                  tenant_id=tenant_id,
                  user_id=current_user.id,
                  plugin_id=datasource_provider_id.plugin_id,
                  provider=datasource_provider_id.provider_name,
                  redirect_uri=redirect_uri,
                  system_credentials=system_credentials or {},
                  credentials=decrypted_credentials,
              )
              datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                  tenant_id=tenant_id,
                  raw_credentials=refreshed_credentials.credentials,
                  provider=provider,
                  plugin_id=plugin_id,
                  datasource_provider=datasource_provider,
              )
              datasource_provider.expires_at = refreshed_credentials.expires_at
              real_credentials_list.append(
                  self.decrypt_datasource_provider_credentials(
                      tenant_id=tenant_id,
                      datasource_provider=datasource_provider,
                      plugin_id=plugin_id,
                      provider=provider,
                  )
              )
          session.commit()
          return real_credentials_list
  def update_datasource_provider_name(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
  ):
      """
      Rename one credential record of a datasource provider.

      Does nothing when the record already has ``name``.

      :raises ValueError: if the record does not exist for this provider, or if another record of
          the same provider in the workspace already uses ``name``
      """
      provider_name = datasource_provider_id.provider_name
      plugin_id = datasource_provider_id.plugin_id
      with Session(db.engine) as session:
          records = session.query(DatasourceProvider)
          target_provider = records.filter_by(
              tenant_id=tenant_id,
              id=credential_id,
              provider=provider_name,
              plugin_id=plugin_id,
          ).first()
          if target_provider is None:
              raise ValueError("provider not found")
          if target_provider.name != name:
              same_name_count = records.filter_by(
                  tenant_id=tenant_id,
                  name=name,
                  provider=provider_name,
                  plugin_id=plugin_id,
              ).count()
              if same_name_count > 0:
                  raise ValueError("Authorization name is already exists")
              target_provider.name = name
              session.commit()
  def set_default_datasource_provider(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
  ):
      """
      Make one credential record the default for its plugin/provider in the workspace.

      Every record of the same plugin/provider currently flagged as default is unflagged first.
      The target record is then flagged, all in a single commit.

      :raises ValueError: if the record does not exist for this provider
      :return: ``{"result": "success"}``
      """
      with Session(db.engine) as session:
          target_provider = (
              session.query(DatasourceProvider)
              .filter_by(
                  tenant_id=tenant_id,
                  id=credential_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if target_provider is None:
              raise ValueError("provider not found")

          current_defaults = session.query(DatasourceProvider).filter_by(
              tenant_id=tenant_id,
              provider=target_provider.provider,
              plugin_id=target_provider.plugin_id,
              is_default=True,
          )
          current_defaults.update({"is_default": False})
          target_provider.is_default = True
          session.commit()
      return {"result": "success"}
  def setup_oauth_custom_client_params(
      self,
      tenant_id: str,
      datasource_provider_id: DatasourceProviderID,
      client_params: dict | None,
      enabled: bool | None,
  ):
      """
      Create or update the workspace's custom OAuth client params for a datasource provider.

      A missing config row is created with empty params and ``enabled=False``. When
      ``client_params`` is given, it is encrypted and stored. Any value equal to HIDDEN_VALUE is
      replaced by the previously stored (decrypted) value for that key, or by UNKNOWN_VALUE when
      there was none. When ``enabled`` is given, the enabled flag is set. Passing neither is a no-op.
      """
      if client_params is None and enabled is None:
          return
      with Session(db.engine) as session:
          config = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if not config:
              config = DatasourceOauthTenantParamConfig(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
                  client_params={},
                  enabled=False,
              )
              session.add(config)

          if client_params is not None:
              # Named to avoid shadowing the module-level `encrypter` helper.
              oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
              previous_params = oauth_encrypter.decrypt(config.client_params) if config else {}
              merged_params: dict = {}
              for key, value in client_params.items():
                  # HIDDEN_VALUE means "unchanged": keep what was stored before.
                  merged_params[key] = value if value != HIDDEN_VALUE else previous_params.get(key, UNKNOWN_VALUE)
              config.client_params = oauth_encrypter.encrypt(merged_params)

          if enabled is not None:
              config.enabled = enabled
          session.commit()
  def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
      """
      Check whether system-level OAuth client params are configured for a datasource provider.

      :return: True when a DatasourceOauthParamConfig row exists for the plugin/provider
      """
      # `Session.no_autoflush` only suspends autoflush and yields the session; it never closes it.
      # Entering the Session itself guarantees it is closed and its connection released on exit.
      with Session(db.engine) as session, session.no_autoflush:
          return (
              session.query(DatasourceOauthParamConfig)
              .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
              .first()
              is not None
          )
  def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
      """
      Check whether the workspace has enabled custom OAuth client params for a datasource provider.

      :return: True when an enabled DatasourceOauthTenantParamConfig row exists
      """
      # `Session.no_autoflush` only suspends autoflush and yields the session; it never closes it.
      # Entering the Session itself guarantees it is closed and its connection released on exit.
      with Session(db.engine) as session, session.no_autoflush:
          return (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
                  enabled=True,
              )
              .count()
              > 0
          )
  def get_tenant_oauth_client(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
  ) -> dict[str, Any] | None:
      """
      Return the workspace's custom OAuth client params for a datasource provider.

      The params are returned whether or not they are enabled.

      :param mask: when True, the decrypted params are passed through the encrypter's
          ``mask_tool_credentials`` before being returned
      :return: decrypted (optionally masked) params, or None when none are configured
      """
      # `Session.no_autoflush` only suspends autoflush and yields the session; it never closes it.
      # Entering the Session itself guarantees it is closed and its connection released on exit.
      with Session(db.engine) as session, session.no_autoflush:
          tenant_oauth_client_params = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
              )
              .first()
          )
          if tenant_oauth_client_params:
              # Named to avoid shadowing the module-level `encrypter` helper.
              oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
              decrypted = oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
              return oauth_encrypter.mask_tool_credentials(decrypted) if mask else decrypted
      return None
  def get_oauth_encrypter(
      self, tenant_id: str, datasource_provider_id: DatasourceProviderID
  ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
      """
      Build the encrypter for a datasource provider's OAuth client params.

      The encrypter is configured from the provider's declared OAuth ``client_schema`` and uses a
      no-op credential cache.

      :raises ValueError: if the provider declares no OAuth schema
      :return: the (encrypter, cache) pair produced by ``create_provider_encrypter``
      """
      declaration = self.provider_manager.fetch_datasource_provider(
          tenant_id=tenant_id, provider_id=str(datasource_provider_id)
      ).declaration
      oauth_schema = declaration.oauth_schema
      if not oauth_schema:
          raise ValueError("Datasource provider oauth schema not found")
      basic_configs = [field.to_basic_provider_config() for field in oauth_schema.client_schema]
      return create_provider_encrypter(
          tenant_id=tenant_id,
          config=basic_configs,
          cache=NoOpProviderCredentialCache(),
      )
  def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
      """
      Resolve the OAuth client params to use for a datasource provider.

      The workspace's enabled custom client params (decrypted) take precedence. Otherwise, when the
      provider's plugin is verified, the system-level client params are returned.

      :raises ValueError: when there are no enabled custom params, and the plugin is unverified or
          has no system params
      """
      provider = datasource_provider_id.provider_name
      plugin_id = datasource_provider_id.plugin_id
      # `Session.no_autoflush` only suspends autoflush and yields the session; it never closes it.
      # Entering the Session itself guarantees it is closed and its connection released on exit.
      with Session(db.engine) as session, session.no_autoflush:
          # get tenant oauth client params
          tenant_oauth_client_params = (
              session.query(DatasourceOauthTenantParamConfig)
              .filter_by(
                  tenant_id=tenant_id,
                  provider=provider,
                  plugin_id=plugin_id,
                  enabled=True,
              )
              .first()
          )
          if tenant_oauth_client_params:
              # Named to avoid shadowing the module-level `encrypter` helper.
              oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
              return oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)

          provider_controller = self.provider_manager.fetch_datasource_provider(
              tenant_id=tenant_id, provider_id=str(datasource_provider_id)
          )
          is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
          if is_verified:
              # fallback to system oauth client params
              oauth_client_params = (
                  session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
              )
              if oauth_client_params:
                  return oauth_client_params.system_credentials
      raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
  @staticmethod
  def generate_next_datasource_provider_name(
      session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
  ) -> str:
      """
      Suggest an unused credential name for a datasource provider.

      The name is based on ``credential_type.get_name()``. ``generate_incremental_name`` makes it
      unique against the names of all existing records, of any auth type, for the same workspace
      and plugin/provider.

      :param session: open session used for the lookup
      :return: the generated name
      """
      existing_records = session.query(DatasourceProvider).filter_by(
          tenant_id=tenant_id,
          provider=provider_id.provider_name,
          plugin_id=provider_id.plugin_id,
      )
      taken_names = [record.name for record in existing_records.all()]
      base_name = f"{credential_type.get_name()}"
      return generate_incremental_name(taken_names, base_name)
  def reauthorize_datasource_oauth_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      avatar_url: str | None,
      expire_at: int,
      credentials: dict,
      credential_id: str,
  ) -> None:
      """
      Replace the OAuth credentials of an existing datasource credential record.

      Secret fields of ``credentials`` (per the provider's OAuth credentials schema) are encrypted
      before being stored; the caller's dict is not modified. The record's expiry is set to
      ``expire_at``, and its avatar URL is replaced only when ``avatar_url`` is non-empty.

      :param name: accepted for interface compatibility; the record's stored name is not changed
      :raises ValueError: if no record with ``credential_id`` exists in the workspace
      """
      # NOTE(review): `name` is not applied to the record, so the stored name is kept on
      # reauthorization. Confirm this is intended; if it should be applied, deduplicate it against
      # the provider's other record names before assigning.
      with Session(db.engine) as session:
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
          with redis_client.lock(lock, timeout=20):
              target_provider = (
                  session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
              )
              if target_provider is None:
                  raise ValueError("provider not found")

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
              )
              # Build a new dict so the caller's plain-text credentials are left untouched.
              encrypted_credentials = {
                  key: (
                      encrypter.encrypt_token(tenant_id, value)
                      if key in provider_credential_secret_variables
                      else value
                  )
                  for key, value in credentials.items()
              }
              target_provider.expires_at = expire_at
              target_provider.encrypted_credentials = encrypted_credentials
              target_provider.avatar_url = avatar_url or target_provider.avatar_url
              session.commit()
  def add_datasource_oauth_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      avatar_url: str | None,
      expire_at: int,
      credentials: dict,
  ) -> None:
      """
      Store a new OAuth credential record for a datasource provider.

      Without ``name``, a name is generated from the credential type. When ``name`` is already used
      by any record of this provider in the workspace, an incremented variant of it is stored
      instead. Secret fields of ``credentials`` (per the provider's OAuth credentials schema) are
      encrypted; the caller's dict is not modified.

      :param avatar_url: stored as given, or ``"default"`` when empty
      :param expire_at: expiry stored on the record
      """
      credential_type = CredentialType.OAUTH2
      provider_name = provider_id.provider_name
      plugin_id = provider_id.plugin_id
      with Session(db.engine) as session:
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
          with redis_client.lock(lock, timeout=60):
              if not name:
                  db_provider_name = self.generate_next_datasource_provider_name(
                      session=session,
                      tenant_id=tenant_id,
                      provider_id=provider_id,
                      credential_type=credential_type,
                  )
              else:
                  db_provider_name = name
                  # A name counts as taken when any record of this provider uses it, whatever its auth
                  # type. update_datasource_provider_name and add_datasource_api_key_provider apply the
                  # same rule, and the generated replacement is deduplicated against all names.
                  name_taken = (
                      session.query(DatasourceProvider)
                      .filter_by(
                          tenant_id=tenant_id,
                          name=db_provider_name,
                          provider=provider_name,
                          plugin_id=plugin_id,
                      )
                      .count()
                      > 0
                  )
                  if name_taken:
                      db_provider_name = generate_incremental_name(
                          [
                              provider.name
                              for provider in session.query(DatasourceProvider).filter_by(
                                  tenant_id=tenant_id,
                                  provider=provider_name,
                                  plugin_id=plugin_id,
                              )
                          ],
                          db_provider_name,
                      )

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
              )
              # Build a new dict so the caller's plain-text credentials are left untouched.
              encrypted_credentials = {
                  key: (
                      encrypter.encrypt_token(tenant_id, value)
                      if key in provider_credential_secret_variables
                      else value
                  )
                  for key, value in credentials.items()
              }
              datasource_provider = DatasourceProvider(
                  tenant_id=tenant_id,
                  name=db_provider_name,
                  provider=provider_name,
                  plugin_id=plugin_id,
                  auth_type=credential_type.value,
                  encrypted_credentials=encrypted_credentials,
                  avatar_url=avatar_url or "default",
                  expires_at=expire_at,
              )
              session.add(datasource_provider)
              session.commit()
  def add_datasource_api_key_provider(
      self,
      name: str | None,
      tenant_id: str,
      provider_id: DatasourceProviderID,
      credentials: dict,
  ) -> None:
      """
      Validate and store a new API-key credential record for a datasource provider.

      The credentials are validated through the plugin before being stored. Secret fields (per
      the provider's credentials schema) are encrypted; the caller's dict is not modified.

      :param name: record name; when empty, a name is generated from the credential type
      :param tenant_id: workspace id
      :param provider_id: plugin/provider identifier
      :param credentials: plain-text credentials
      :raises ValueError: if the name is already used by a record of this provider, or if the
          plugin rejects the credentials
      """
      provider_name = provider_id.provider_name
      plugin_id = provider_id.plugin_id
      with Session(db.engine) as session:
          # `.value` keeps the lock key format consistent with the OAuth creation paths.
          lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY.value}"
          with redis_client.lock(lock, timeout=20):
              db_provider_name = name or self.generate_next_datasource_provider_name(
                  session=session,
                  tenant_id=tenant_id,
                  provider_id=provider_id,
                  credential_type=CredentialType.API_KEY,
              )
              # check name is exist
              if (
                  session.query(DatasourceProvider)
                  .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
                  .count()
                  > 0
              ):
                  raise ValueError("Authorization name is already exists")

              try:
                  self.provider_manager.validate_provider_credentials(
                      tenant_id=tenant_id,
                      user_id=current_user.id,
                      provider=provider_name,
                      plugin_id=plugin_id,
                      credentials=credentials,
                  )
              except Exception as e:
                  # Keep the original error as the cause so the plugin failure stays debuggable.
                  raise ValueError(f"Failed to validate credentials: {str(e)}") from e

              provider_credential_secret_variables = self.extract_secret_variables(
                  tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
              )
              # Build a new dict so the caller's plain-text credentials are left untouched.
              encrypted_credentials = {
                  key: (
                      encrypter.encrypt_token(tenant_id, value)
                      if key in provider_credential_secret_variables
                      else value
                  )
                  for key, value in credentials.items()
              }
              datasource_provider = DatasourceProvider(
                  tenant_id=tenant_id,
                  name=db_provider_name,
                  provider=provider_name,
                  plugin_id=plugin_id,
                  auth_type=CredentialType.API_KEY.value,
                  encrypted_credentials=encrypted_credentials,
              )
              session.add(datasource_provider)
              session.commit()
  def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
      """
      List the names of the secret-input fields in a datasource provider's credential form.

      API-key credentials use the provider's ``credentials_schema``. OAuth2 credentials use the
      ``credentials_schema`` of the provider's OAuth schema.

      :param tenant_id: workspace id
      :param provider_id: ``plugin_id/provider`` string
      :param credential_type: which credential form to inspect
      :raises ValueError: for OAuth2 on a provider without an OAuth schema, or for any other
          credential type
      :return: names of fields whose form type is SECRET_INPUT, in schema order
      """
      declaration = self.provider_manager.fetch_datasource_provider(
          tenant_id=tenant_id, provider_id=provider_id
      ).declaration
      if credential_type == CredentialType.API_KEY:
          form_schemas = list(declaration.credentials_schema)
      elif credential_type == CredentialType.OAUTH2:
          if not declaration.oauth_schema:
              raise ValueError("Datasource provider oauth schema not found")
          form_schemas = list(declaration.oauth_schema.credentials_schema)
      else:
          raise ValueError(f"Invalid credential type: {credential_type}")
      secret_type = FormType.SECRET_INPUT.value
      return [schema.name for schema in form_schemas if schema.type.value == secret_type]
  def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
      """
      List a datasource provider's credential records with secret fields obfuscated.

      Secret fields are the secret-input fields of the form matching each record's auth type. Their
      stored (encrypted) values are passed through ``encrypter.obfuscated_token``. The default
      record is the one flagged as default, or else the oldest one.

      :param tenant_id: workspace id
      :param provider: provider name
      :param plugin_id: plugin id of the provider
      :return: one dict per record with keys credential, type, name, avatar_url, id, is_default
      """
      # Get all provider configurations of the current workspace
      datasource_providers: list[DatasourceProvider] = (
          db.session.query(DatasourceProvider)
          .filter(
              DatasourceProvider.tenant_id == tenant_id,
              DatasourceProvider.provider == provider,
              DatasourceProvider.plugin_id == plugin_id,
          )
          .all()
      )
      if not datasource_providers:
          return []

      default_provider = (
          db.session.query(DatasourceProvider.id)
          .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
          .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
          .first()
      )
      default_provider_id = default_provider.id if default_provider else None

      # extract_secret_variables fetches the provider declaration from the plugin on every call, so
      # resolve it once per auth type instead of once per record.
      secret_variables_by_auth_type: dict[str, list[str]] = {}
      copy_credentials_list = []
      for datasource_provider in datasource_providers:
          auth_type = datasource_provider.auth_type
          if auth_type not in secret_variables_by_auth_type:
              secret_variables_by_auth_type[auth_type] = self.extract_secret_variables(
                  tenant_id=tenant_id,
                  provider_id=f"{plugin_id}/{provider}",
                  credential_type=CredentialType.of(auth_type),
              )
          credential_secret_variables = secret_variables_by_auth_type[auth_type]

          # Obfuscate provider credentials
          copy_credentials = {
              key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
              for key, value in datasource_provider.encrypted_credentials.items()
          }
          copy_credentials_list.append(
              {
                  "credential": copy_credentials,
                  "type": auth_type,
                  "name": datasource_provider.name,
                  "avatar_url": datasource_provider.avatar_url,
                  "id": datasource_provider.id,
                  # Always a bool (the previous `id and ...` form could yield None).
                  "is_default": default_provider_id is not None and datasource_provider.id == default_provider_id,
              }
          )
      return copy_credentials_list
  def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
      """
      Describe every datasource provider installed in the workspace together with its credentials.

      Each entry contains:
      - the provider's identity (provider, plugin ids, icon, name, label, description, author);
      - its credential records as returned by list_datasource_credentials;
      - its API-key credential schema;
      - when the provider declares one, its OAuth schema. This is extended with the workspace's
        masked custom client params, whether they are enabled, whether system OAuth params exist,
        and the OAuth callback URL. Otherwise ``oauth_schema`` is None.

      :param tenant_id: workspace id
      :return: one dict per installed datasource provider
      """
      installed_providers = PluginDatasourceManager().fetch_installed_datasource_providers(tenant_id)
      results: list[dict] = []
      for datasource in installed_providers:
          declaration = datasource.declaration
          identity = declaration.identity
          datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
          credentials = self.list_datasource_credentials(
              tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
          )
          redirect_uri = (
              f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
          )
          entry = {
              "provider": datasource.provider,
              "plugin_id": datasource.plugin_id,
              "plugin_unique_identifier": datasource.plugin_unique_identifier,
              "icon": identity.icon,
              "name": identity.name.split("/")[-1],
              "label": identity.label.model_dump(),
              "description": identity.description.model_dump(),
              "author": identity.author,
              "credentials_list": credentials,
              "credential_schema": [field.model_dump() for field in declaration.credentials_schema],
          }
          oauth_schema = declaration.oauth_schema
          if oauth_schema:
              entry["oauth_schema"] = {
                  "client_schema": [field.model_dump() for field in oauth_schema.client_schema],
                  "credentials_schema": [field.model_dump() for field in oauth_schema.credentials_schema],
                  "oauth_custom_client_params": self.get_tenant_oauth_client(
                      tenant_id, datasource_provider_id, mask=True
                  ),
                  "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
                      tenant_id, datasource_provider_id
                  ),
                  "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
                  "redirect_uri": redirect_uri,
              }
          else:
              entry["oauth_schema"] = None
          results.append(entry)
      return results
  736. def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
  737. """
  738. get hard code datasource credentials.
  739. :return:
  740. """
  741. # get all plugin providers
  742. manager = PluginDatasourceManager()
  743. datasources = manager.fetch_installed_datasource_providers(tenant_id)
  744. datasource_credentials = []
  745. for datasource in datasources:
  746. if datasource.plugin_id in [
  747. "langgenius/firecrawl_datasource",
  748. "langgenius/notion_datasource",
  749. "langgenius/jina_datasource",
  750. ]:
  751. datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
  752. credentials = self.list_datasource_credentials(
  753. tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
  754. )
  755. redirect_uri = "{}/console/api/oauth/plugin/{}/datasource/callback".format(
  756. dify_config.CONSOLE_API_URL, datasource_provider_id
  757. )
  758. datasource_credentials.append(
  759. {
  760. "provider": datasource.provider,
  761. "plugin_id": datasource.plugin_id,
  762. "plugin_unique_identifier": datasource.plugin_unique_identifier,
  763. "icon": datasource.declaration.identity.icon,
  764. "name": datasource.declaration.identity.name.split("/")[-1],
  765. "label": datasource.declaration.identity.label.model_dump(),
  766. "description": datasource.declaration.identity.description.model_dump(),
  767. "author": datasource.declaration.identity.author,
  768. "credentials_list": credentials,
  769. "credential_schema": [
  770. credential.model_dump() for credential in datasource.declaration.credentials_schema
  771. ],
  772. "oauth_schema": {
  773. "client_schema": [
  774. client_schema.model_dump()
  775. for client_schema in datasource.declaration.oauth_schema.client_schema
  776. ],
  777. "credentials_schema": [
  778. credential_schema.model_dump()
  779. for credential_schema in datasource.declaration.oauth_schema.credentials_schema
  780. ],
  781. "oauth_custom_client_params": self.get_tenant_oauth_client(
  782. tenant_id, datasource_provider_id, mask=True
  783. ),
  784. "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
  785. tenant_id, datasource_provider_id
  786. ),
  787. "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
  788. "redirect_uri": redirect_uri,
  789. }
  790. if datasource.declaration.oauth_schema
  791. else None,
  792. }
  793. )
  794. return datasource_credentials
  795. def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
  796. """
  797. get datasource credentials.
  798. :param tenant_id: workspace id
  799. :param provider_id: provider id
  800. :return:
  801. """
  802. # Get all provider configurations of the current workspace
  803. datasource_providers: list[DatasourceProvider] = (
  804. db.session.query(DatasourceProvider)
  805. .filter(
  806. DatasourceProvider.tenant_id == tenant_id,
  807. DatasourceProvider.provider == provider,
  808. DatasourceProvider.plugin_id == plugin_id,
  809. )
  810. .all()
  811. )
  812. if not datasource_providers:
  813. return []
  814. copy_credentials_list = []
  815. for datasource_provider in datasource_providers:
  816. encrypted_credentials = datasource_provider.encrypted_credentials
  817. # Get provider credential secret variables
  818. credential_secret_variables = self.extract_secret_variables(
  819. tenant_id=tenant_id,
  820. provider_id=f"{plugin_id}/{provider}",
  821. credential_type=CredentialType.of(datasource_provider.auth_type),
  822. )
  823. # Obfuscate provider credentials
  824. copy_credentials = encrypted_credentials.copy()
  825. for key, value in copy_credentials.items():
  826. if key in credential_secret_variables:
  827. copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
  828. copy_credentials_list.append(
  829. {
  830. "credentials": copy_credentials,
  831. "type": datasource_provider.auth_type,
  832. }
  833. )
  834. return copy_credentials_list
  835. def update_datasource_credentials(
  836. self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
  837. ) -> None:
  838. """
  839. update datasource credentials.
  840. """
  841. with Session(db.engine) as session:
  842. datasource_provider = (
  843. session.query(DatasourceProvider)
  844. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  845. .first()
  846. )
  847. if not datasource_provider:
  848. raise ValueError("Datasource provider not found")
  849. # update name
  850. if name and name != datasource_provider.name:
  851. if (
  852. session.query(DatasourceProvider)
  853. .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
  854. .count()
  855. > 0
  856. ):
  857. raise ValueError("Authorization name is already exists")
  858. datasource_provider.name = name
  859. # update credentials
  860. if credentials:
  861. secret_variables = self.extract_secret_variables(
  862. tenant_id=tenant_id,
  863. provider_id=f"{plugin_id}/{provider}",
  864. credential_type=CredentialType.of(datasource_provider.auth_type),
  865. )
  866. original_credentials = {
  867. key: value if key not in secret_variables else encrypter.decrypt_token(tenant_id, value)
  868. for key, value in datasource_provider.encrypted_credentials.items()
  869. }
  870. new_credentials = {
  871. key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
  872. for key, value in credentials.items()
  873. }
  874. try:
  875. self.provider_manager.validate_provider_credentials(
  876. tenant_id=tenant_id,
  877. user_id=current_user.id,
  878. provider=provider,
  879. plugin_id=plugin_id,
  880. credentials=new_credentials,
  881. )
  882. except Exception as e:
  883. raise ValueError(f"Failed to validate credentials: {str(e)}")
  884. encrypted_credentials = {}
  885. for key, value in new_credentials.items():
  886. if key in secret_variables:
  887. encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
  888. else:
  889. encrypted_credentials[key] = value
  890. datasource_provider.encrypted_credentials = encrypted_credentials
  891. session.commit()
  892. def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
  893. """
  894. remove datasource credentials.
  895. :param tenant_id: workspace id
  896. :param provider: provider name
  897. :param plugin_id: plugin id
  898. :return:
  899. """
  900. datasource_provider = (
  901. db.session.query(DatasourceProvider)
  902. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  903. .first()
  904. )
  905. if datasource_provider:
  906. db.session.delete(datasource_provider)
  907. db.session.commit()