Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

datasource_provider_service.py 38KB

Line numbers 1–882
  1. import logging
  2. from typing import Any
  3. from flask_login import current_user
  4. from sqlalchemy.orm import Session
  5. from configs import dify_config
  6. from constants import HIDDEN_VALUE, UNKNOWN_VALUE
  7. from core.helper import encrypter
  8. from core.helper.name_generator import generate_incremental_name
  9. from core.helper.provider_cache import NoOpProviderCredentialCache
  10. from core.model_runtime.entities.provider_entities import FormType
  11. from core.plugin.entities.plugin import DatasourceProviderID
  12. from core.plugin.impl.datasource import PluginDatasourceManager
  13. from core.tools.entities.tool_entities import CredentialType
  14. from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
  15. from extensions.ext_database import db
  16. from extensions.ext_redis import redis_client
  17. from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
  18. logger = logging.getLogger(__name__)
  19. class DatasourceProviderService:
  20. """
  21. Model Provider Service
  22. """
  23. def __init__(self) -> None:
  24. self.provider_manager = PluginDatasourceManager()
  25. def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
  26. """
  27. remove oauth custom client params
  28. """
  29. with Session(db.engine) as session:
  30. session.query(DatasourceOauthTenantParamConfig).filter_by(
  31. tenant_id=tenant_id,
  32. provider=datasource_provider_id.provider_name,
  33. plugin_id=datasource_provider_id.plugin_id,
  34. ).delete()
  35. session.commit()
  36. def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
  37. """
  38. get default credentials
  39. """
  40. with Session(db.engine) as session:
  41. datasource_provider = (
  42. session.query(DatasourceProvider)
  43. .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
  44. .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
  45. .first()
  46. )
  47. if not datasource_provider:
  48. return {}
  49. encrypted_credentials = datasource_provider.encrypted_credentials
  50. # Get provider credential secret variables
  51. credential_secret_variables = self.extract_secret_variables(
  52. tenant_id=tenant_id,
  53. provider_id=f"{plugin_id}/{provider}",
  54. credential_type=CredentialType.of(datasource_provider.auth_type),
  55. )
  56. # Obfuscate provider credentials
  57. copy_credentials = encrypted_credentials.copy()
  58. for key, value in copy_credentials.items():
  59. if key in credential_secret_variables:
  60. copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
  61. return copy_credentials
  62. def get_real_credential_by_id(
  63. self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
  64. ) -> dict[str, Any]:
  65. """
  66. get credential by id
  67. """
  68. with Session(db.engine) as session:
  69. datasource_provider = (
  70. session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
  71. )
  72. if not datasource_provider:
  73. return {}
  74. encrypted_credentials = datasource_provider.encrypted_credentials
  75. # Get provider credential secret variables
  76. credential_secret_variables = self.extract_secret_variables(
  77. tenant_id=tenant_id,
  78. provider_id=f"{plugin_id}/{provider}",
  79. credential_type=CredentialType.of(datasource_provider.auth_type),
  80. )
  81. # Obfuscate provider credentials
  82. copy_credentials = encrypted_credentials.copy()
  83. for key, value in copy_credentials.items():
  84. if key in credential_secret_variables:
  85. copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
  86. return copy_credentials
  87. def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
  88. """
  89. get default credential
  90. """
  91. with Session(db.engine) as session:
  92. datasource_provider = (
  93. session.query(DatasourceProvider)
  94. .filter_by(tenant_id=tenant_id, is_default=True, provider=provider, plugin_id=plugin_id)
  95. .first()
  96. )
  97. if not datasource_provider:
  98. return {}
  99. encrypted_credentials = datasource_provider.encrypted_credentials
  100. # Get provider credential secret variables
  101. credential_secret_variables = self.extract_secret_variables(
  102. tenant_id=tenant_id,
  103. provider_id=f"{plugin_id}/{provider}",
  104. credential_type=CredentialType.of(datasource_provider.auth_type),
  105. )
  106. # Obfuscate provider credentials
  107. copy_credentials = encrypted_credentials.copy()
  108. for key, value in copy_credentials.items():
  109. if key in credential_secret_variables:
  110. copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
  111. return copy_credentials
  112. def update_datasource_provider_name(
  113. self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
  114. ):
  115. """
  116. update datasource provider name
  117. """
  118. with Session(db.engine) as session:
  119. target_provider = (
  120. session.query(DatasourceProvider)
  121. .filter_by(
  122. tenant_id=tenant_id,
  123. id=credential_id,
  124. provider=datasource_provider_id.provider_name,
  125. plugin_id=datasource_provider_id.plugin_id,
  126. )
  127. .first()
  128. )
  129. if target_provider is None:
  130. raise ValueError("provider not found")
  131. if target_provider.name == name:
  132. return
  133. # check name is exist
  134. if (
  135. session.query(DatasourceProvider)
  136. .filter_by(
  137. tenant_id=tenant_id,
  138. name=name,
  139. provider=datasource_provider_id.provider_name,
  140. plugin_id=datasource_provider_id.plugin_id,
  141. )
  142. .count()
  143. > 0
  144. ):
  145. raise ValueError("Authorization name is already exists")
  146. target_provider.name = name
  147. session.commit()
  148. return
  149. def set_default_datasource_provider(
  150. self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
  151. ):
  152. """
  153. set default datasource provider
  154. """
  155. with Session(db.engine) as session:
  156. # get provider
  157. target_provider = (
  158. session.query(DatasourceProvider)
  159. .filter_by(
  160. tenant_id=tenant_id,
  161. id=credential_id,
  162. provider=datasource_provider_id.provider_name,
  163. plugin_id=datasource_provider_id.plugin_id,
  164. )
  165. .first()
  166. )
  167. if target_provider is None:
  168. raise ValueError("provider not found")
  169. # clear default provider
  170. session.query(DatasourceProvider).filter_by(
  171. tenant_id=tenant_id,
  172. provider=target_provider.provider,
  173. plugin_id=target_provider.plugin_id,
  174. is_default=True,
  175. ).update({"is_default": False})
  176. # set new default provider
  177. target_provider.is_default = True
  178. session.commit()
  179. return {"result": "success"}
  180. def setup_oauth_custom_client_params(
  181. self,
  182. tenant_id: str,
  183. datasource_provider_id: DatasourceProviderID,
  184. client_params: dict | None,
  185. enabled: bool | None,
  186. ):
  187. """
  188. setup oauth custom client params
  189. """
  190. if client_params is None and enabled is None:
  191. return
  192. with Session(db.engine) as session:
  193. tenant_oauth_client_params = (
  194. session.query(DatasourceOauthTenantParamConfig)
  195. .filter_by(
  196. tenant_id=tenant_id,
  197. provider=datasource_provider_id.provider_name,
  198. plugin_id=datasource_provider_id.plugin_id,
  199. )
  200. .first()
  201. )
  202. if not tenant_oauth_client_params:
  203. tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
  204. tenant_id=tenant_id,
  205. provider=datasource_provider_id.provider_name,
  206. plugin_id=datasource_provider_id.plugin_id,
  207. client_params={},
  208. enabled=False,
  209. )
  210. session.add(tenant_oauth_client_params)
  211. if client_params is not None:
  212. encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
  213. original_params = (
  214. encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
  215. )
  216. new_params: dict = {
  217. key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
  218. for key, value in client_params.items()
  219. }
  220. tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
  221. if enabled is not None:
  222. tenant_oauth_client_params.enabled = enabled
  223. session.commit()
  224. def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
  225. """
  226. check if system oauth params exist
  227. """
  228. with Session(db.engine).no_autoflush as session:
  229. return (
  230. session.query(DatasourceOauthParamConfig)
  231. .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
  232. .first()
  233. is not None
  234. )
  235. def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
  236. """
  237. check if tenant oauth params is enabled
  238. """
  239. with Session(db.engine).no_autoflush as session:
  240. return (
  241. session.query(DatasourceOauthTenantParamConfig)
  242. .filter_by(
  243. tenant_id=tenant_id,
  244. provider=datasource_provider_id.provider_name,
  245. plugin_id=datasource_provider_id.plugin_id,
  246. enabled=True,
  247. )
  248. .count()
  249. > 0
  250. )
  251. def get_tenant_oauth_client(
  252. self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
  253. ) -> dict[str, Any] | None:
  254. """
  255. get tenant oauth client
  256. """
  257. with Session(db.engine).no_autoflush as session:
  258. tenant_oauth_client_params = (
  259. session.query(DatasourceOauthTenantParamConfig)
  260. .filter_by(
  261. tenant_id=tenant_id,
  262. provider=datasource_provider_id.provider_name,
  263. plugin_id=datasource_provider_id.plugin_id,
  264. )
  265. .first()
  266. )
  267. if tenant_oauth_client_params:
  268. encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
  269. if mask:
  270. return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
  271. else:
  272. return encrypter.decrypt(tenant_oauth_client_params.client_params)
  273. return None
  274. def get_oauth_encrypter(
  275. self, tenant_id: str, datasource_provider_id: DatasourceProviderID
  276. ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
  277. """
  278. get oauth encrypter
  279. """
  280. datasource_provider = self.provider_manager.fetch_datasource_provider(
  281. tenant_id=tenant_id, provider_id=str(datasource_provider_id)
  282. )
  283. if not datasource_provider.declaration.oauth_schema:
  284. raise ValueError("Datasource provider oauth schema not found")
  285. client_schema = datasource_provider.declaration.oauth_schema.client_schema
  286. return create_provider_encrypter(
  287. tenant_id=tenant_id,
  288. config=[x.to_basic_provider_config() for x in client_schema],
  289. cache=NoOpProviderCredentialCache(),
  290. )
  291. def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
  292. """
  293. get oauth client
  294. """
  295. provider = datasource_provider_id.provider_name
  296. plugin_id = datasource_provider_id.plugin_id
  297. with Session(db.engine).no_autoflush as session:
  298. # get tenant oauth client params
  299. tenant_oauth_client_params = (
  300. session.query(DatasourceOauthTenantParamConfig)
  301. .filter_by(
  302. tenant_id=tenant_id,
  303. provider=provider,
  304. plugin_id=plugin_id,
  305. enabled=True,
  306. )
  307. .first()
  308. )
  309. if tenant_oauth_client_params:
  310. encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
  311. return encrypter.decrypt(tenant_oauth_client_params.client_params)
  312. # fallback to system oauth client params
  313. oauth_client_params = (
  314. session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
  315. )
  316. if oauth_client_params:
  317. return oauth_client_params.system_credentials
  318. raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
  319. @staticmethod
  320. def generate_next_datasource_provider_name(
  321. session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
  322. ) -> str:
  323. db_providers = (
  324. session.query(DatasourceProvider)
  325. .filter_by(
  326. tenant_id=tenant_id,
  327. provider=provider_id.provider_name,
  328. plugin_id=provider_id.plugin_id,
  329. )
  330. .all()
  331. )
  332. return generate_incremental_name(
  333. [provider.name for provider in db_providers],
  334. f"{credential_type.get_name()}",
  335. )
  336. def reauthorize_datasource_oauth_provider(
  337. self,
  338. name: str | None,
  339. tenant_id: str,
  340. provider_id: DatasourceProviderID,
  341. avatar_url: str | None,
  342. credentials: dict,
  343. credential_id: str,
  344. ) -> None:
  345. """
  346. update datasource oauth provider
  347. """
  348. with Session(db.engine) as session:
  349. lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
  350. with redis_client.lock(lock, timeout=20):
  351. target_provider = (
  352. session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
  353. )
  354. if target_provider is None:
  355. raise ValueError("provider not found")
  356. db_provider_name = name
  357. if not db_provider_name:
  358. db_provider_name = target_provider.name
  359. else:
  360. name_conflict = (
  361. session.query(DatasourceProvider)
  362. .filter_by(
  363. tenant_id=tenant_id,
  364. name=db_provider_name,
  365. provider=provider_id.provider_name,
  366. plugin_id=provider_id.plugin_id,
  367. auth_type=CredentialType.OAUTH2.value,
  368. )
  369. .count()
  370. )
  371. if name_conflict > 0:
  372. db_provider_name = generate_incremental_name(
  373. [
  374. provider.name
  375. for provider in session.query(DatasourceProvider).filter_by(
  376. tenant_id=tenant_id,
  377. provider=provider_id.provider_name,
  378. plugin_id=provider_id.plugin_id,
  379. )
  380. ],
  381. db_provider_name,
  382. )
  383. provider_credential_secret_variables = self.extract_secret_variables(
  384. tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
  385. )
  386. for key, value in credentials.items():
  387. if key in provider_credential_secret_variables:
  388. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  389. target_provider.encrypted_credentials = credentials
  390. target_provider.avatar_url = avatar_url or target_provider.avatar_url
  391. session.commit()
  392. def add_datasource_oauth_provider(
  393. self,
  394. name: str | None,
  395. tenant_id: str,
  396. provider_id: DatasourceProviderID,
  397. avatar_url: str | None,
  398. credentials: dict,
  399. ) -> None:
  400. """
  401. add datasource oauth provider
  402. """
  403. credential_type = CredentialType.OAUTH2
  404. with Session(db.engine) as session:
  405. lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
  406. with redis_client.lock(lock, timeout=20):
  407. db_provider_name = name
  408. if not db_provider_name:
  409. db_provider_name = self.generate_next_datasource_provider_name(
  410. session=session,
  411. tenant_id=tenant_id,
  412. provider_id=provider_id,
  413. credential_type=credential_type,
  414. )
  415. else:
  416. if (
  417. session.query(DatasourceProvider)
  418. .filter_by(
  419. tenant_id=tenant_id,
  420. name=db_provider_name,
  421. provider=provider_id.provider_name,
  422. plugin_id=provider_id.plugin_id,
  423. auth_type=credential_type.value,
  424. )
  425. .count()
  426. > 0
  427. ):
  428. db_provider_name = generate_incremental_name(
  429. [
  430. provider.name
  431. for provider in session.query(DatasourceProvider).filter_by(
  432. tenant_id=tenant_id,
  433. provider=provider_id.provider_name,
  434. plugin_id=provider_id.plugin_id,
  435. )
  436. ],
  437. db_provider_name,
  438. )
  439. provider_credential_secret_variables = self.extract_secret_variables(
  440. tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
  441. )
  442. for key, value in credentials.items():
  443. if key in provider_credential_secret_variables:
  444. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  445. datasource_provider = DatasourceProvider(
  446. tenant_id=tenant_id,
  447. name=db_provider_name,
  448. provider=provider_id.provider_name,
  449. plugin_id=provider_id.plugin_id,
  450. auth_type=credential_type.value,
  451. encrypted_credentials=credentials,
  452. avatar_url=avatar_url or "default",
  453. )
  454. session.add(datasource_provider)
  455. session.commit()
  456. def add_datasource_api_key_provider(
  457. self,
  458. name: str | None,
  459. tenant_id: str,
  460. provider_id: DatasourceProviderID,
  461. credentials: dict,
  462. ) -> None:
  463. """
  464. validate datasource provider credentials.
  465. :param tenant_id:
  466. :param provider:
  467. :param credentials:
  468. """
  469. provider_name = provider_id.provider_name
  470. plugin_id = provider_id.plugin_id
  471. with Session(db.engine) as session:
  472. lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
  473. with redis_client.lock(lock, timeout=20):
  474. db_provider_name = name or self.generate_next_datasource_provider_name(
  475. session=session,
  476. tenant_id=tenant_id,
  477. provider_id=provider_id,
  478. credential_type=CredentialType.API_KEY,
  479. )
  480. # check name is exist
  481. if (
  482. session.query(DatasourceProvider)
  483. .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
  484. .count()
  485. > 0
  486. ):
  487. raise ValueError("Authorization name is already exists")
  488. try:
  489. self.provider_manager.validate_provider_credentials(
  490. tenant_id=tenant_id,
  491. user_id=current_user.id,
  492. provider=provider_name,
  493. plugin_id=plugin_id,
  494. credentials=credentials,
  495. )
  496. except Exception as e:
  497. raise ValueError(f"Failed to validate credentials: {str(e)}")
  498. provider_credential_secret_variables = self.extract_secret_variables(
  499. tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
  500. )
  501. for key, value in credentials.items():
  502. if key in provider_credential_secret_variables:
  503. # if send [__HIDDEN__] in secret input, it will be same as original value
  504. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  505. datasource_provider = DatasourceProvider(
  506. tenant_id=tenant_id,
  507. name=db_provider_name,
  508. provider=provider_name,
  509. plugin_id=plugin_id,
  510. auth_type=CredentialType.API_KEY.value,
  511. encrypted_credentials=credentials,
  512. )
  513. db.session.add(datasource_provider)
  514. db.session.commit()
  515. def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
  516. """
  517. Extract secret input form variables.
  518. :param credential_form_schemas:
  519. :return:
  520. """
  521. datasource_provider = self.provider_manager.fetch_datasource_provider(
  522. tenant_id=tenant_id, provider_id=provider_id
  523. )
  524. credential_form_schemas = []
  525. if credential_type == CredentialType.API_KEY:
  526. credential_form_schemas = list(datasource_provider.declaration.credentials_schema)
  527. elif credential_type == CredentialType.OAUTH2:
  528. if not datasource_provider.declaration.oauth_schema:
  529. raise ValueError("Datasource provider oauth schema not found")
  530. credential_form_schemas = list(datasource_provider.declaration.oauth_schema.credentials_schema)
  531. else:
  532. raise ValueError(f"Invalid credential type: {credential_type}")
  533. secret_input_form_variables = []
  534. for credential_form_schema in credential_form_schemas:
  535. if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
  536. secret_input_form_variables.append(credential_form_schema.name)
  537. return secret_input_form_variables
  538. def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
  539. """
  540. get datasource credentials.
  541. :param tenant_id: workspace id
  542. :param provider_id: provider id
  543. :return:
  544. """
  545. # Get all provider configurations of the current workspace
  546. datasource_providers: list[DatasourceProvider] = (
  547. db.session.query(DatasourceProvider)
  548. .filter(
  549. DatasourceProvider.tenant_id == tenant_id,
  550. DatasourceProvider.provider == provider,
  551. DatasourceProvider.plugin_id == plugin_id,
  552. )
  553. .all()
  554. )
  555. if not datasource_providers:
  556. return []
  557. copy_credentials_list = []
  558. default_provider = (
  559. db.session.query(DatasourceProvider.id)
  560. .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
  561. .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
  562. .first()
  563. )
  564. default_provider_id = default_provider.id if default_provider else None
  565. for datasource_provider in datasource_providers:
  566. encrypted_credentials = datasource_provider.encrypted_credentials
  567. # Get provider credential secret variables
  568. credential_secret_variables = self.extract_secret_variables(
  569. tenant_id=tenant_id,
  570. provider_id=f"{plugin_id}/{provider}",
  571. credential_type=CredentialType.of(datasource_provider.auth_type),
  572. )
  573. # Obfuscate provider credentials
  574. copy_credentials = encrypted_credentials.copy()
  575. for key, value in copy_credentials.items():
  576. if key in credential_secret_variables:
  577. copy_credentials[key] = encrypter.obfuscated_token(value)
  578. copy_credentials_list.append(
  579. {
  580. "credential": copy_credentials,
  581. "type": datasource_provider.auth_type,
  582. "name": datasource_provider.name,
  583. "avatar_url": datasource_provider.avatar_url,
  584. "id": datasource_provider.id,
  585. "is_default": default_provider_id and datasource_provider.id == default_provider_id,
  586. }
  587. )
  588. return copy_credentials_list
  589. def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
  590. """
  591. get datasource credentials.
  592. :return:
  593. """
  594. # get all plugin providers
  595. manager = PluginDatasourceManager()
  596. datasources = manager.fetch_installed_datasource_providers(tenant_id)
  597. datasource_credentials = []
  598. for datasource in datasources:
  599. datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
  600. credentials = self.get_datasource_credentials(
  601. tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
  602. )
  603. redirect_uri = (
  604. f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
  605. )
  606. datasource_credentials.append(
  607. {
  608. "provider": datasource.provider,
  609. "plugin_id": datasource.plugin_id,
  610. "plugin_unique_identifier": datasource.plugin_unique_identifier,
  611. "icon": datasource.declaration.identity.icon,
  612. "name": datasource.declaration.identity.name.split("/")[-1],
  613. "label": datasource.declaration.identity.label.model_dump(),
  614. "description": datasource.declaration.identity.description.model_dump(),
  615. "author": datasource.declaration.identity.author,
  616. "credentials_list": credentials,
  617. "credential_schema": [
  618. credential.model_dump() for credential in datasource.declaration.credentials_schema
  619. ],
  620. "oauth_schema": {
  621. "client_schema": [
  622. client_schema.model_dump()
  623. for client_schema in datasource.declaration.oauth_schema.client_schema
  624. ],
  625. "credentials_schema": [
  626. credential_schema.model_dump()
  627. for credential_schema in datasource.declaration.oauth_schema.credentials_schema
  628. ],
  629. "oauth_custom_client_params": self.get_tenant_oauth_client(
  630. tenant_id, datasource_provider_id, mask=True
  631. ),
  632. "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
  633. tenant_id, datasource_provider_id
  634. ),
  635. "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
  636. "redirect_uri": redirect_uri,
  637. }
  638. if datasource.declaration.oauth_schema
  639. else None,
  640. }
  641. )
  642. return datasource_credentials
  643. def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
  644. """
  645. get hard code datasource credentials.
  646. :return:
  647. """
  648. # get all plugin providers
  649. manager = PluginDatasourceManager()
  650. datasources = manager.fetch_installed_datasource_providers(tenant_id)
  651. datasource_credentials = []
  652. for datasource in datasources:
  653. if datasource.plugin_id in [
  654. "langgenius/firecrawl_datasource",
  655. "langgenius/notion_datasource",
  656. "langgenius/jina_datasource",
  657. ]:
  658. datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
  659. credentials = self.get_datasource_credentials(
  660. tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
  661. )
  662. redirect_uri = (
  663. f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
  664. )
  665. datasource_credentials.append(
  666. {
  667. "provider": datasource.provider,
  668. "plugin_id": datasource.plugin_id,
  669. "plugin_unique_identifier": datasource.plugin_unique_identifier,
  670. "icon": datasource.declaration.identity.icon,
  671. "name": datasource.declaration.identity.name.split("/")[-1],
  672. "label": datasource.declaration.identity.label.model_dump(),
  673. "description": datasource.declaration.identity.description.model_dump(),
  674. "author": datasource.declaration.identity.author,
  675. "credentials_list": credentials,
  676. "credential_schema": [
  677. credential.model_dump() for credential in datasource.declaration.credentials_schema
  678. ],
  679. "oauth_schema": {
  680. "client_schema": [
  681. client_schema.model_dump()
  682. for client_schema in datasource.declaration.oauth_schema.client_schema
  683. ],
  684. "credentials_schema": [
  685. credential_schema.model_dump()
  686. for credential_schema in datasource.declaration.oauth_schema.credentials_schema
  687. ],
  688. "oauth_custom_client_params": self.get_tenant_oauth_client(
  689. tenant_id, datasource_provider_id, mask=True
  690. ),
  691. "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
  692. tenant_id, datasource_provider_id
  693. ),
  694. "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
  695. "redirect_uri": redirect_uri,
  696. }
  697. if datasource.declaration.oauth_schema
  698. else None,
  699. }
  700. )
  701. return datasource_credentials
  702. def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
  703. """
  704. get datasource credentials.
  705. :param tenant_id: workspace id
  706. :param provider_id: provider id
  707. :return:
  708. """
  709. # Get all provider configurations of the current workspace
  710. datasource_providers: list[DatasourceProvider] = (
  711. db.session.query(DatasourceProvider)
  712. .filter(
  713. DatasourceProvider.tenant_id == tenant_id,
  714. DatasourceProvider.provider == provider,
  715. DatasourceProvider.plugin_id == plugin_id,
  716. )
  717. .all()
  718. )
  719. if not datasource_providers:
  720. return []
  721. copy_credentials_list = []
  722. for datasource_provider in datasource_providers:
  723. encrypted_credentials = datasource_provider.encrypted_credentials
  724. # Get provider credential secret variables
  725. credential_secret_variables = self.extract_secret_variables(
  726. tenant_id=tenant_id,
  727. provider_id=f"{plugin_id}/{provider}",
  728. credential_type=CredentialType.of(datasource_provider.auth_type),
  729. )
  730. # Obfuscate provider credentials
  731. copy_credentials = encrypted_credentials.copy()
  732. for key, value in copy_credentials.items():
  733. if key in credential_secret_variables:
  734. copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
  735. copy_credentials_list.append(
  736. {
  737. "credentials": copy_credentials,
  738. "type": datasource_provider.auth_type,
  739. }
  740. )
  741. return copy_credentials_list
  742. def update_datasource_credentials(
  743. self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
  744. ) -> None:
  745. """
  746. update datasource credentials.
  747. """
  748. with Session(db.engine) as session:
  749. datasource_provider = (
  750. session.query(DatasourceProvider)
  751. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  752. .first()
  753. )
  754. if not datasource_provider:
  755. raise ValueError("Datasource provider not found")
  756. # update name
  757. if name and name != datasource_provider.name:
  758. if (
  759. session.query(DatasourceProvider)
  760. .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
  761. .count()
  762. > 0
  763. ):
  764. raise ValueError("Authorization name is already exists")
  765. datasource_provider.name = name
  766. # update credentials
  767. if credentials:
  768. secret_variables = self.extract_secret_variables(
  769. tenant_id=tenant_id,
  770. provider_id=f"{plugin_id}/{provider}",
  771. credential_type=CredentialType.of(datasource_provider.auth_type),
  772. )
  773. original_credentials = {
  774. key: value if key not in secret_variables else encrypter.decrypt_token(tenant_id, value)
  775. for key, value in datasource_provider.encrypted_credentials.items()
  776. }
  777. new_credentials = {
  778. key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
  779. for key, value in credentials.items()
  780. }
  781. try:
  782. self.provider_manager.validate_provider_credentials(
  783. tenant_id=tenant_id,
  784. user_id=current_user.id,
  785. provider=provider,
  786. plugin_id=plugin_id,
  787. credentials=new_credentials,
  788. )
  789. except Exception as e:
  790. raise ValueError(f"Failed to validate credentials: {str(e)}")
  791. encrypted_credentials = {}
  792. for key, value in new_credentials.items():
  793. if key in secret_variables:
  794. encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
  795. else:
  796. encrypted_credentials[key] = value
  797. datasource_provider.encrypted_credentials = encrypted_credentials
  798. session.commit()
  799. def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
  800. """
  801. remove datasource credentials.
  802. :param tenant_id: workspace id
  803. :param provider: provider name
  804. :param plugin_id: plugin id
  805. :return:
  806. """
  807. datasource_provider = (
  808. db.session.query(DatasourceProvider)
  809. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  810. .first()
  811. )
  812. if datasource_provider:
  813. db.session.delete(datasource_provider)
  814. db.session.commit()