You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_oauth.py 19KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. from unittest.mock import MagicMock, patch
  2. import pytest
  3. from flask import Flask
  4. from controllers.console.auth.oauth import (
  5. OAuthCallback,
  6. OAuthLogin,
  7. _generate_account,
  8. _get_account_by_openid_or_email,
  9. get_oauth_providers,
  10. )
  11. from libs.oauth import OAuthUserInfo
  12. from models.account import AccountStatus
  13. from services.errors.account import AccountRegisterError
  14. class TestGetOAuthProviders:
  15. @pytest.fixture
  16. def app(self):
  17. app = Flask(__name__)
  18. app.config["TESTING"] = True
  19. return app
  20. @pytest.mark.parametrize(
  21. ("github_config", "google_config", "expected_github", "expected_google"),
  22. [
  23. # Both providers configured
  24. (
  25. {"id": "github_id", "secret": "github_secret"},
  26. {"id": "google_id", "secret": "google_secret"},
  27. True,
  28. True,
  29. ),
  30. # Only GitHub configured
  31. ({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
  32. # Only Google configured
  33. ({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
  34. # No providers configured
  35. ({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
  36. ],
  37. )
  38. @patch("controllers.console.auth.oauth.dify_config")
  39. def test_should_configure_oauth_providers_correctly(
  40. self, mock_config, app, github_config, google_config, expected_github, expected_google
  41. ):
  42. mock_config.GITHUB_CLIENT_ID = github_config["id"]
  43. mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
  44. mock_config.GOOGLE_CLIENT_ID = google_config["id"]
  45. mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
  46. mock_config.CONSOLE_API_URL = "http://localhost"
  47. with app.app_context():
  48. providers = get_oauth_providers()
  49. assert (providers["github"] is not None) == expected_github
  50. assert (providers["google"] is not None) == expected_google
  51. class TestOAuthLogin:
  52. @pytest.fixture
  53. def resource(self):
  54. return OAuthLogin()
  55. @pytest.fixture
  56. def app(self):
  57. app = Flask(__name__)
  58. app.config["TESTING"] = True
  59. return app
  60. @pytest.fixture
  61. def mock_oauth_provider(self):
  62. provider = MagicMock()
  63. provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
  64. return provider
  65. @pytest.mark.parametrize(
  66. ("invite_token", "expected_token"),
  67. [
  68. (None, None),
  69. ("test_invite_token", "test_invite_token"),
  70. ("", None),
  71. ],
  72. )
  73. @patch("controllers.console.auth.oauth.get_oauth_providers")
  74. @patch("controllers.console.auth.oauth.redirect")
  75. def test_should_handle_oauth_login_with_various_tokens(
  76. self,
  77. mock_redirect,
  78. mock_get_providers,
  79. resource,
  80. app,
  81. mock_oauth_provider,
  82. invite_token,
  83. expected_token,
  84. ):
  85. mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
  86. query_string = f"invite_token={invite_token}" if invite_token else ""
  87. with app.test_request_context(f"/auth/oauth/github?{query_string}"):
  88. resource.get("github")
  89. mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
  90. mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
  91. @pytest.mark.parametrize(
  92. ("provider", "expected_error"),
  93. [
  94. ("invalid_provider", "Invalid provider"),
  95. ("github", "Invalid provider"), # When GitHub is not configured
  96. ("google", "Invalid provider"), # When Google is not configured
  97. ],
  98. )
  99. @patch("controllers.console.auth.oauth.get_oauth_providers")
  100. def test_should_return_error_for_invalid_providers(
  101. self, mock_get_providers, resource, app, provider, expected_error
  102. ):
  103. mock_get_providers.return_value = {"github": None, "google": None}
  104. with app.test_request_context(f"/auth/oauth/{provider}"):
  105. response, status_code = resource.get(provider)
  106. assert status_code == 400
  107. assert response["error"] == expected_error
  108. class TestOAuthCallback:
  109. @pytest.fixture
  110. def resource(self):
  111. return OAuthCallback()
  112. @pytest.fixture
  113. def app(self):
  114. app = Flask(__name__)
  115. app.config["TESTING"] = True
  116. return app
  117. @pytest.fixture
  118. def oauth_setup(self):
  119. """Common OAuth setup for callback tests"""
  120. oauth_provider = MagicMock()
  121. oauth_provider.get_access_token.return_value = "access_token"
  122. oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
  123. account = MagicMock()
  124. account.status = AccountStatus.ACTIVE.value
  125. token_pair = MagicMock()
  126. token_pair.access_token = "jwt_access_token"
  127. token_pair.refresh_token = "jwt_refresh_token"
  128. return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
  129. @patch("controllers.console.auth.oauth.dify_config")
  130. @patch("controllers.console.auth.oauth.get_oauth_providers")
  131. @patch("controllers.console.auth.oauth._generate_account")
  132. @patch("controllers.console.auth.oauth.AccountService")
  133. @patch("controllers.console.auth.oauth.TenantService")
  134. @patch("controllers.console.auth.oauth.redirect")
  135. def test_should_handle_successful_oauth_callback(
  136. self,
  137. mock_redirect,
  138. mock_tenant_service,
  139. mock_account_service,
  140. mock_generate_account,
  141. mock_get_providers,
  142. mock_config,
  143. resource,
  144. app,
  145. oauth_setup,
  146. ):
  147. mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
  148. mock_get_providers.return_value = {"github": oauth_setup["provider"]}
  149. mock_generate_account.return_value = oauth_setup["account"]
  150. mock_account_service.login.return_value = oauth_setup["token_pair"]
  151. with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
  152. resource.get("github")
  153. oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
  154. oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
  155. mock_redirect.assert_called_once_with(
  156. "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
  157. )
  158. @pytest.mark.parametrize(
  159. ("exception", "expected_error"),
  160. [
  161. (Exception("OAuth error"), "OAuth process failed"),
  162. (ValueError("Invalid token"), "OAuth process failed"),
  163. (KeyError("Missing key"), "OAuth process failed"),
  164. ],
  165. )
  166. @patch("controllers.console.auth.oauth.db")
  167. @patch("controllers.console.auth.oauth.get_oauth_providers")
  168. def test_should_handle_oauth_exceptions(
  169. self, mock_get_providers, mock_db, resource, app, exception, expected_error
  170. ):
  171. # Mock database session
  172. mock_db.session = MagicMock()
  173. mock_db.session.rollback = MagicMock()
  174. # Import the real requests module to create a proper exception
  175. import httpx
  176. request_exception = httpx.RequestError("OAuth error")
  177. request_exception.response = MagicMock()
  178. request_exception.response.text = str(exception)
  179. mock_oauth_provider = MagicMock()
  180. mock_oauth_provider.get_access_token.side_effect = request_exception
  181. mock_get_providers.return_value = {"github": mock_oauth_provider}
  182. with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
  183. response, status_code = resource.get("github")
  184. assert status_code == 400
  185. assert response["error"] == expected_error
  186. @pytest.mark.parametrize(
  187. ("account_status", "expected_redirect"),
  188. [
  189. (AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
  190. # CLOSED status: Currently NOT handled, will proceed to login (security issue)
  191. # This documents actual behavior. See test_defensive_check_for_closed_account_status for details
  192. (
  193. AccountStatus.CLOSED.value,
  194. "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
  195. ),
  196. ],
  197. )
  198. @patch("controllers.console.auth.oauth.AccountService")
  199. @patch("controllers.console.auth.oauth.TenantService")
  200. @patch("controllers.console.auth.oauth.db")
  201. @patch("controllers.console.auth.oauth.dify_config")
  202. @patch("controllers.console.auth.oauth.get_oauth_providers")
  203. @patch("controllers.console.auth.oauth._generate_account")
  204. @patch("controllers.console.auth.oauth.redirect")
  205. def test_should_redirect_based_on_account_status(
  206. self,
  207. mock_redirect,
  208. mock_generate_account,
  209. mock_get_providers,
  210. mock_config,
  211. mock_db,
  212. mock_tenant_service,
  213. mock_account_service,
  214. resource,
  215. app,
  216. oauth_setup,
  217. account_status,
  218. expected_redirect,
  219. ):
  220. # Mock database session
  221. mock_db.session = MagicMock()
  222. mock_db.session.rollback = MagicMock()
  223. mock_db.session.commit = MagicMock()
  224. mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
  225. mock_get_providers.return_value = {"github": oauth_setup["provider"]}
  226. account = MagicMock()
  227. account.status = account_status
  228. account.id = "123"
  229. mock_generate_account.return_value = account
  230. # Mock login for CLOSED status
  231. mock_token_pair = MagicMock()
  232. mock_token_pair.access_token = "jwt_access_token"
  233. mock_token_pair.refresh_token = "jwt_refresh_token"
  234. mock_account_service.login.return_value = mock_token_pair
  235. with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
  236. resource.get("github")
  237. mock_redirect.assert_called_once_with(expected_redirect)
  238. @patch("controllers.console.auth.oauth.dify_config")
  239. @patch("controllers.console.auth.oauth.get_oauth_providers")
  240. @patch("controllers.console.auth.oauth._generate_account")
  241. @patch("controllers.console.auth.oauth.db")
  242. @patch("controllers.console.auth.oauth.TenantService")
  243. @patch("controllers.console.auth.oauth.AccountService")
  244. def test_should_activate_pending_account(
  245. self,
  246. mock_account_service,
  247. mock_tenant_service,
  248. mock_db,
  249. mock_generate_account,
  250. mock_get_providers,
  251. mock_config,
  252. resource,
  253. app,
  254. oauth_setup,
  255. ):
  256. mock_get_providers.return_value = {"github": oauth_setup["provider"]}
  257. mock_account = MagicMock()
  258. mock_account.status = AccountStatus.PENDING.value
  259. mock_generate_account.return_value = mock_account
  260. with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
  261. resource.get("github")
  262. assert mock_account.status == AccountStatus.ACTIVE.value
  263. assert mock_account.initialized_at is not None
  264. mock_db.session.commit.assert_called_once()
  265. @patch("controllers.console.auth.oauth.dify_config")
  266. @patch("controllers.console.auth.oauth.get_oauth_providers")
  267. @patch("controllers.console.auth.oauth._generate_account")
  268. @patch("controllers.console.auth.oauth.db")
  269. @patch("controllers.console.auth.oauth.TenantService")
  270. @patch("controllers.console.auth.oauth.AccountService")
  271. @patch("controllers.console.auth.oauth.redirect")
  272. def test_defensive_check_for_closed_account_status(
  273. self,
  274. mock_redirect,
  275. mock_account_service,
  276. mock_tenant_service,
  277. mock_db,
  278. mock_generate_account,
  279. mock_get_providers,
  280. mock_config,
  281. resource,
  282. app,
  283. oauth_setup,
  284. ):
  285. """Defensive test for CLOSED account status handling in OAuth callback.
  286. This is a defensive test documenting expected security behavior for CLOSED accounts.
  287. Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
  288. Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
  289. Context:
  290. - AccountStatus.CLOSED is defined in the enum but never used in production
  291. - The close_account() method exists but is never called
  292. - Account deletion uses external service instead of status change
  293. - All authentication services (OAuth, password, email) don't check CLOSED status
  294. TODO: If CLOSED status is implemented in the future:
  295. 1. Update OAuth callback to check for CLOSED status
  296. 2. Add similar checks to all authentication services for consistency
  297. 3. Update this test to verify the rejection behavior
  298. Security consideration: Until properly implemented, CLOSED status provides no protection.
  299. """
  300. # Setup
  301. mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
  302. mock_get_providers.return_value = {"github": oauth_setup["provider"]}
  303. # Create account with CLOSED status
  304. closed_account = MagicMock()
  305. closed_account.status = AccountStatus.CLOSED.value
  306. closed_account.id = "123"
  307. closed_account.name = "Closed Account"
  308. mock_generate_account.return_value = closed_account
  309. # Mock successful login (current behavior)
  310. mock_token_pair = MagicMock()
  311. mock_token_pair.access_token = "jwt_access_token"
  312. mock_token_pair.refresh_token = "jwt_refresh_token"
  313. mock_account_service.login.return_value = mock_token_pair
  314. # Execute OAuth callback
  315. with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
  316. resource.get("github")
  317. # Verify current behavior: login succeeds (this is NOT ideal)
  318. mock_redirect.assert_called_once_with(
  319. "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
  320. )
  321. mock_account_service.login.assert_called_once()
  322. # Document expected behavior in comments:
  323. # Expected: mock_redirect.assert_called_once_with(
  324. # "http://localhost:3000/signin?message=Account is closed."
  325. # )
  326. # Expected: mock_account_service.login.assert_not_called()
  327. class TestAccountGeneration:
  328. @pytest.fixture
  329. def user_info(self):
  330. return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
  331. @pytest.fixture
  332. def mock_account(self):
  333. account = MagicMock()
  334. account.name = "Test User"
  335. return account
  336. @patch("controllers.console.auth.oauth.db")
  337. @patch("controllers.console.auth.oauth.Account")
  338. @patch("controllers.console.auth.oauth.Session")
  339. @patch("controllers.console.auth.oauth.select")
  340. def test_should_get_account_by_openid_or_email(
  341. self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
  342. ):
  343. # Mock db.engine for Session creation
  344. mock_db.engine = MagicMock()
  345. # Test OpenID found
  346. mock_account_model.get_by_openid.return_value = mock_account
  347. result = _get_account_by_openid_or_email("github", user_info)
  348. assert result == mock_account
  349. mock_account_model.get_by_openid.assert_called_once_with("github", "123")
  350. # Test fallback to email
  351. mock_account_model.get_by_openid.return_value = None
  352. mock_session_instance = MagicMock()
  353. mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
  354. mock_session.return_value.__enter__.return_value = mock_session_instance
  355. result = _get_account_by_openid_or_email("github", user_info)
  356. assert result == mock_account
  357. @pytest.mark.parametrize(
  358. ("allow_register", "existing_account", "should_create"),
  359. [
  360. (True, None, True), # New account creation allowed
  361. (True, "existing", False), # Existing account
  362. (False, None, False), # Registration not allowed
  363. ],
  364. )
  365. @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
  366. @patch("controllers.console.auth.oauth.FeatureService")
  367. @patch("controllers.console.auth.oauth.RegisterService")
  368. @patch("controllers.console.auth.oauth.AccountService")
  369. @patch("controllers.console.auth.oauth.TenantService")
  370. @patch("controllers.console.auth.oauth.db")
  371. def test_should_handle_account_generation_scenarios(
  372. self,
  373. mock_db,
  374. mock_tenant_service,
  375. mock_account_service,
  376. mock_register_service,
  377. mock_feature_service,
  378. mock_get_account,
  379. app,
  380. user_info,
  381. mock_account,
  382. allow_register,
  383. existing_account,
  384. should_create,
  385. ):
  386. mock_get_account.return_value = mock_account if existing_account else None
  387. mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
  388. mock_register_service.register.return_value = mock_account
  389. with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
  390. if not allow_register and not existing_account:
  391. with pytest.raises(AccountRegisterError):
  392. _generate_account("github", user_info)
  393. else:
  394. result = _generate_account("github", user_info)
  395. assert result == mock_account
  396. if should_create:
  397. mock_register_service.register.assert_called_once_with(
  398. email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
  399. )
  400. @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
  401. @patch("controllers.console.auth.oauth.TenantService")
  402. @patch("controllers.console.auth.oauth.FeatureService")
  403. @patch("controllers.console.auth.oauth.AccountService")
  404. @patch("controllers.console.auth.oauth.tenant_was_created")
  405. def test_should_create_workspace_for_account_without_tenant(
  406. self,
  407. mock_event,
  408. mock_account_service,
  409. mock_feature_service,
  410. mock_tenant_service,
  411. mock_get_account,
  412. app,
  413. user_info,
  414. mock_account,
  415. ):
  416. mock_get_account.return_value = mock_account
  417. mock_tenant_service.get_join_tenants.return_value = []
  418. mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
  419. mock_new_tenant = MagicMock()
  420. mock_tenant_service.create_tenant.return_value = mock_new_tenant
  421. with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
  422. result = _generate_account("github", user_info)
  423. assert result == mock_account
  424. mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
  425. mock_tenant_service.create_tenant_member.assert_called_once_with(
  426. mock_new_tenant, mock_account, role="owner"
  427. )
  428. mock_event.send.assert_called_once_with(mock_new_tenant)