Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

test_wraps.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. from unittest.mock import MagicMock, patch
  2. import pytest
  3. from flask import Flask
  4. from flask_login import LoginManager, UserMixin
  5. from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
  6. from controllers.console.workspace.error import AccountNotInitializedError
  7. from controllers.console.wraps import (
  8. account_initialization_required,
  9. cloud_edition_billing_rate_limit_check,
  10. cloud_edition_billing_resource_check,
  11. enterprise_license_required,
  12. only_edition_cloud,
  13. only_edition_enterprise,
  14. only_edition_self_hosted,
  15. setup_required,
  16. )
  17. from models.account import AccountStatus
  18. from services.feature_service import LicenseStatus
  19. class MockUser(UserMixin):
  20. """Simple User class for testing."""
  21. def __init__(self, user_id: str):
  22. self.id = user_id
  23. self.current_tenant_id = "tenant123"
  24. def get_id(self) -> str:
  25. return self.id
  26. def create_app_with_login():
  27. """Create a Flask app with LoginManager configured."""
  28. app = Flask(__name__)
  29. app.config["SECRET_KEY"] = "test-secret-key"
  30. login_manager = LoginManager()
  31. login_manager.init_app(app)
  32. @login_manager.user_loader
  33. def load_user(user_id: str):
  34. return MockUser(user_id)
  35. return app
  36. class TestAccountInitialization:
  37. """Test account initialization decorator"""
  38. def test_should_allow_initialized_account(self):
  39. """Test that initialized accounts can access protected views"""
  40. # Arrange
  41. mock_user = MagicMock()
  42. mock_user.status = AccountStatus.ACTIVE
  43. @account_initialization_required
  44. def protected_view():
  45. return "success"
  46. # Act
  47. with patch("controllers.console.wraps.current_user", mock_user):
  48. result = protected_view()
  49. # Assert
  50. assert result == "success"
  51. def test_should_reject_uninitialized_account(self):
  52. """Test that uninitialized accounts raise AccountNotInitializedError"""
  53. # Arrange
  54. mock_user = MagicMock()
  55. mock_user.status = AccountStatus.UNINITIALIZED
  56. @account_initialization_required
  57. def protected_view():
  58. return "success"
  59. # Act & Assert
  60. with patch("controllers.console.wraps.current_user", mock_user):
  61. with pytest.raises(AccountNotInitializedError):
  62. protected_view()
  63. class TestEditionChecks:
  64. """Test edition-specific decorators"""
  65. def test_only_edition_cloud_allows_cloud_edition(self):
  66. """Test cloud edition decorator allows CLOUD edition"""
  67. # Arrange
  68. @only_edition_cloud
  69. def cloud_view():
  70. return "cloud_success"
  71. # Act
  72. with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"):
  73. result = cloud_view()
  74. # Assert
  75. assert result == "cloud_success"
  76. def test_only_edition_cloud_rejects_other_editions(self):
  77. """Test cloud edition decorator rejects non-CLOUD editions"""
  78. # Arrange
  79. app = Flask(__name__)
  80. @only_edition_cloud
  81. def cloud_view():
  82. return "cloud_success"
  83. # Act & Assert
  84. with app.test_request_context():
  85. with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
  86. with pytest.raises(Exception) as exc_info:
  87. cloud_view()
  88. assert exc_info.value.code == 404
  89. def test_only_edition_enterprise_allows_when_enabled(self):
  90. """Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True"""
  91. # Arrange
  92. @only_edition_enterprise
  93. def enterprise_view():
  94. return "enterprise_success"
  95. # Act
  96. with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True):
  97. result = enterprise_view()
  98. # Assert
  99. assert result == "enterprise_success"
  100. def test_only_edition_self_hosted_allows_self_hosted(self):
  101. """Test self-hosted edition decorator allows SELF_HOSTED edition"""
  102. # Arrange
  103. @only_edition_self_hosted
  104. def self_hosted_view():
  105. return "self_hosted_success"
  106. # Act
  107. with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
  108. result = self_hosted_view()
  109. # Assert
  110. assert result == "self_hosted_success"
  111. class TestBillingResourceLimits:
  112. """Test billing resource limit decorators"""
  113. def test_should_allow_when_under_resource_limit(self):
  114. """Test that requests are allowed when under resource limits"""
  115. # Arrange
  116. mock_features = MagicMock()
  117. mock_features.billing.enabled = True
  118. mock_features.members.limit = 10
  119. mock_features.members.size = 5
  120. @cloud_edition_billing_resource_check("members")
  121. def add_member():
  122. return "member_added"
  123. # Act
  124. with patch("controllers.console.wraps.current_user"):
  125. with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
  126. result = add_member()
  127. # Assert
  128. assert result == "member_added"
  129. def test_should_reject_when_over_resource_limit(self):
  130. """Test that requests are rejected when over resource limits"""
  131. # Arrange
  132. app = create_app_with_login()
  133. mock_features = MagicMock()
  134. mock_features.billing.enabled = True
  135. mock_features.members.limit = 10
  136. mock_features.members.size = 10
  137. @cloud_edition_billing_resource_check("members")
  138. def add_member():
  139. return "member_added"
  140. # Act & Assert
  141. with app.test_request_context():
  142. with patch("controllers.console.wraps.current_user", MockUser("test_user")):
  143. with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
  144. with pytest.raises(Exception) as exc_info:
  145. add_member()
  146. assert exc_info.value.code == 403
  147. assert "members has reached the limit" in str(exc_info.value.description)
  148. def test_should_check_source_for_documents_limit(self):
  149. """Test document limit checks request source"""
  150. # Arrange
  151. app = create_app_with_login()
  152. mock_features = MagicMock()
  153. mock_features.billing.enabled = True
  154. mock_features.documents_upload_quota.limit = 100
  155. mock_features.documents_upload_quota.size = 100
  156. @cloud_edition_billing_resource_check("documents")
  157. def upload_document():
  158. return "document_uploaded"
  159. # Test 1: Should reject when source is datasets
  160. with app.test_request_context("/?source=datasets"):
  161. with patch("controllers.console.wraps.current_user", MockUser("test_user")):
  162. with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
  163. with pytest.raises(Exception) as exc_info:
  164. upload_document()
  165. assert exc_info.value.code == 403
  166. # Test 2: Should allow when source is not datasets
  167. with app.test_request_context("/?source=other"):
  168. with patch("controllers.console.wraps.current_user", MockUser("test_user")):
  169. with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
  170. result = upload_document()
  171. assert result == "document_uploaded"
  172. class TestRateLimiting:
  173. """Test rate limiting decorator"""
  174. @patch("controllers.console.wraps.redis_client")
  175. @patch("controllers.console.wraps.db")
  176. def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
  177. """Test that requests within rate limit are allowed"""
  178. # Arrange
  179. mock_rate_limit = MagicMock()
  180. mock_rate_limit.enabled = True
  181. mock_rate_limit.limit = 10
  182. mock_redis.zcard.return_value = 5 # 5 requests in window
  183. @cloud_edition_billing_rate_limit_check("knowledge")
  184. def knowledge_request():
  185. return "knowledge_success"
  186. # Act
  187. with patch("controllers.console.wraps.current_user"):
  188. with patch(
  189. "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
  190. ):
  191. result = knowledge_request()
  192. # Assert
  193. assert result == "knowledge_success"
  194. mock_redis.zadd.assert_called_once()
  195. mock_redis.zremrangebyscore.assert_called_once()
  196. @patch("controllers.console.wraps.redis_client")
  197. @patch("controllers.console.wraps.db")
  198. def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
  199. """Test that requests over rate limit are rejected and logged"""
  200. # Arrange
  201. app = create_app_with_login()
  202. mock_rate_limit = MagicMock()
  203. mock_rate_limit.enabled = True
  204. mock_rate_limit.limit = 10
  205. mock_rate_limit.subscription_plan = "pro"
  206. mock_redis.zcard.return_value = 11 # Over limit
  207. mock_session = MagicMock()
  208. mock_db.session = mock_session
  209. @cloud_edition_billing_rate_limit_check("knowledge")
  210. def knowledge_request():
  211. return "knowledge_success"
  212. # Act & Assert
  213. with app.test_request_context():
  214. with patch("controllers.console.wraps.current_user", MockUser("test_user")):
  215. with patch(
  216. "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
  217. ):
  218. with pytest.raises(Exception) as exc_info:
  219. knowledge_request()
  220. # Verify error
  221. assert exc_info.value.code == 403
  222. assert "rate limit" in str(exc_info.value.description)
  223. # Verify rate limit log was created
  224. mock_session.add.assert_called_once()
  225. mock_session.commit.assert_called_once()
  226. class TestSystemSetup:
  227. """Test system setup decorator"""
  228. @patch("controllers.console.wraps.db")
  229. def test_should_allow_when_setup_complete(self, mock_db):
  230. """Test that requests are allowed when setup is complete"""
  231. # Arrange
  232. mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
  233. @setup_required
  234. def admin_view():
  235. return "admin_success"
  236. # Act
  237. with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
  238. result = admin_view()
  239. # Assert
  240. assert result == "admin_success"
  241. @patch("controllers.console.wraps.db")
  242. @patch("controllers.console.wraps.os.environ.get")
  243. def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
  244. """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
  245. # Arrange
  246. mock_db.session.query.return_value.first.return_value = None # No setup
  247. mock_environ_get.return_value = "some_password"
  248. @setup_required
  249. def admin_view():
  250. return "admin_success"
  251. # Act & Assert
  252. with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
  253. with pytest.raises(NotInitValidateError):
  254. admin_view()
  255. @patch("controllers.console.wraps.db")
  256. @patch("controllers.console.wraps.os.environ.get")
  257. def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
  258. """Test NotSetupError when no INIT_PASSWORD and setup not complete"""
  259. # Arrange
  260. mock_db.session.query.return_value.first.return_value = None # No setup
  261. mock_environ_get.return_value = None # No INIT_PASSWORD
  262. @setup_required
  263. def admin_view():
  264. return "admin_success"
  265. # Act & Assert
  266. with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
  267. with pytest.raises(NotSetupError):
  268. admin_view()
  269. class TestEnterpriseLicense:
  270. """Test enterprise license decorator"""
  271. def test_should_allow_with_valid_license(self):
  272. """Test that valid licenses allow access"""
  273. # Arrange
  274. mock_settings = MagicMock()
  275. mock_settings.license.status = LicenseStatus.ACTIVE
  276. @enterprise_license_required
  277. def enterprise_feature():
  278. return "enterprise_success"
  279. # Act
  280. with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
  281. result = enterprise_feature()
  282. # Assert
  283. assert result == "enterprise_success"
  284. @pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
  285. def test_should_reject_with_invalid_license(self, invalid_status):
  286. """Test that invalid licenses raise UnauthorizedAndForceLogout"""
  287. # Arrange
  288. mock_settings = MagicMock()
  289. mock_settings.license.status = invalid_status
  290. @enterprise_license_required
  291. def enterprise_feature():
  292. return "enterprise_success"
  293. # Act & Assert
  294. with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
  295. with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
  296. enterprise_feature()
  297. assert "license is invalid" in str(exc_info.value)