| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- from unittest.mock import MagicMock, patch
-
- import pytest
- from flask import Flask
- from flask_login import LoginManager, UserMixin
-
- from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
- from controllers.console.workspace.error import AccountNotInitializedError
- from controllers.console.wraps import (
- account_initialization_required,
- cloud_edition_billing_rate_limit_check,
- cloud_edition_billing_resource_check,
- enterprise_license_required,
- only_edition_cloud,
- only_edition_enterprise,
- only_edition_self_hosted,
- setup_required,
- )
- from models.account import AccountStatus
- from services.feature_service import LicenseStatus
-
-
class MockUser(UserMixin):
    """Minimal flask-login user stand-in used by the decorator tests.

    Sets ``id`` from the constructor argument and a fixed
    ``current_tenant_id`` of ``"tenant123"``.
    """

    def __init__(self, user_id: str) -> None:
        """Attach the fixed tenant id and remember ``user_id`` as ``id``."""
        self.current_tenant_id = "tenant123"
        self.id = user_id

    def get_id(self) -> str:
        """Return ``id`` exactly as it was passed to the constructor."""
        return self.id
-
-
def create_app_with_login():
    """Build a Flask app with a LoginManager whose loader returns MockUser.

    Returns:
        The Flask application, with ``SECRET_KEY`` set and the login
        manager initialised against it.
    """
    flask_app = Flask(__name__)
    flask_app.config["SECRET_KEY"] = "test-secret-key"

    manager = LoginManager()
    manager.init_app(flask_app)

    def load_user(user_id: str):
        # Every id resolves to a fresh MockUser; nothing is looked up.
        return MockUser(user_id)

    # Explicit registration; same effect as decorating load_user.
    manager.user_loader(load_user)

    return flask_app
-
-
class TestAccountInitialization:
    """Tests for the ``account_initialization_required`` decorator."""

    def test_should_allow_initialized_account(self):
        """A user whose status is ACTIVE gets the wrapped view's return value."""
        # Arrange
        active_user = MagicMock(status=AccountStatus.ACTIVE)

        @account_initialization_required
        def protected_view():
            return "success"

        user_patch = patch("controllers.console.wraps.current_user", active_user)

        # Act
        with user_patch:
            outcome = protected_view()

        # Assert
        assert outcome == "success"

    def test_should_reject_uninitialized_account(self):
        """A user whose status is UNINITIALIZED triggers AccountNotInitializedError."""
        # Arrange
        pending_user = MagicMock(status=AccountStatus.UNINITIALIZED)

        @account_initialization_required
        def protected_view():
            return "success"

        user_patch = patch("controllers.console.wraps.current_user", pending_user)

        # Act & Assert
        with user_patch, pytest.raises(AccountNotInitializedError):
            protected_view()
-
-
class TestEditionChecks:
    """Tests for the ``only_edition_*`` decorators."""

    def test_only_edition_cloud_allows_cloud_edition(self):
        """With EDITION patched to "CLOUD", the cloud-only view runs."""

        # Arrange
        @only_edition_cloud
        def cloud_view():
            return "cloud_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "CLOUD")

        # Act
        with edition_patch:
            outcome = cloud_view()

        # Assert
        assert outcome == "cloud_success"

    def test_only_edition_cloud_rejects_other_editions(self):
        """With EDITION patched to "SELF_HOSTED", the cloud-only view raises with code 404."""
        # Arrange
        flask_app = Flask(__name__)

        @only_edition_cloud
        def cloud_view():
            return "cloud_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act & Assert
        with flask_app.test_request_context(), edition_patch:
            with pytest.raises(Exception) as exc_info:
                cloud_view()
            assert exc_info.value.code == 404

    def test_only_edition_enterprise_allows_when_enabled(self):
        """With ENTERPRISE_ENABLED patched to True, the enterprise-only view runs."""

        # Arrange
        @only_edition_enterprise
        def enterprise_view():
            return "enterprise_success"

        enabled_patch = patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True)

        # Act
        with enabled_patch:
            outcome = enterprise_view()

        # Assert
        assert outcome == "enterprise_success"

    def test_only_edition_self_hosted_allows_self_hosted(self):
        """With EDITION patched to "SELF_HOSTED", the self-hosted-only view runs."""

        # Arrange
        @only_edition_self_hosted
        def self_hosted_view():
            return "self_hosted_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act
        with edition_patch:
            outcome = self_hosted_view()

        # Assert
        assert outcome == "self_hosted_success"
-
-
class TestBillingResourceLimits:
    """Tests for the ``cloud_edition_billing_resource_check`` decorator factory."""

    def test_should_allow_when_under_resource_limit(self):
        """With members size 5 and limit 10, the wrapped view returns normally."""
        # Arrange
        features = MagicMock(**{"billing.enabled": True, "members.limit": 10, "members.size": 5})

        @cloud_edition_billing_resource_check("members")
        def add_member():
            return "member_added"

        features_patch = patch("controllers.console.wraps.FeatureService.get_features", return_value=features)

        # Act
        with patch("controllers.console.wraps.current_user"), features_patch:
            outcome = add_member()

        # Assert
        assert outcome == "member_added"

    def test_should_reject_when_over_resource_limit(self):
        """With members size equal to limit (10), the view raises with code 403."""
        # Arrange
        flask_app = create_app_with_login()
        features = MagicMock(**{"billing.enabled": True, "members.limit": 10, "members.size": 10})

        @cloud_edition_billing_resource_check("members")
        def add_member():
            return "member_added"

        user_patch = patch("controllers.console.wraps.current_user", MockUser("test_user"))
        features_patch = patch("controllers.console.wraps.FeatureService.get_features", return_value=features)

        # Act & Assert
        with flask_app.test_request_context(), user_patch, features_patch:
            with pytest.raises(Exception) as exc_info:
                add_member()
            assert exc_info.value.code == 403
            assert "members has reached the limit" in str(exc_info.value.description)

    def test_should_check_source_for_documents_limit(self):
        """The exhausted documents quota is enforced for ``source=datasets`` and not for ``source=other``."""
        # Arrange
        flask_app = create_app_with_login()
        quota_features = MagicMock(
            **{
                "billing.enabled": True,
                "documents_upload_quota.limit": 100,
                "documents_upload_quota.size": 100,
            }
        )

        @cloud_edition_billing_resource_check("documents")
        def upload_document():
            return "document_uploaded"

        def fresh_patches():
            # New patcher objects per scenario so neither run reuses the other's.
            return (
                patch("controllers.console.wraps.current_user", MockUser("test_user")),
                patch("controllers.console.wraps.FeatureService.get_features", return_value=quota_features),
            )

        # Test 1: Should reject when source is datasets
        user_patch, features_patch = fresh_patches()
        with flask_app.test_request_context("/?source=datasets"), user_patch, features_patch:
            with pytest.raises(Exception) as exc_info:
                upload_document()
            assert exc_info.value.code == 403

        # Test 2: Should allow when source is not datasets
        user_patch, features_patch = fresh_patches()
        with flask_app.test_request_context("/?source=other"), user_patch, features_patch:
            outcome = upload_document()
            assert outcome == "document_uploaded"
-
-
class TestRateLimiting:
    """Tests for the ``cloud_edition_billing_rate_limit_check`` decorator factory."""

    @patch("controllers.console.wraps.redis_client")
    @patch("controllers.console.wraps.db")
    def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
        """With zcard reporting 5 against a limit of 10, the view runs and redis is updated."""
        # Arrange
        rate_limit = MagicMock(enabled=True, limit=10)
        mock_redis.zcard.return_value = 5  # below the limit of 10 set above

        @cloud_edition_billing_rate_limit_check("knowledge")
        def knowledge_request():
            return "knowledge_success"

        rate_limit_patch = patch(
            "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=rate_limit
        )

        # Act
        with patch("controllers.console.wraps.current_user"), rate_limit_patch:
            outcome = knowledge_request()

        # Assert
        assert outcome == "knowledge_success"
        mock_redis.zadd.assert_called_once()
        mock_redis.zremrangebyscore.assert_called_once()

    @patch("controllers.console.wraps.redis_client")
    @patch("controllers.console.wraps.db")
    def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
        """With zcard reporting 11 against a limit of 10, the view raises 403 and one add/commit is made."""
        # Arrange
        flask_app = create_app_with_login()
        rate_limit = MagicMock(enabled=True, limit=10, subscription_plan="pro")
        mock_redis.zcard.return_value = 11  # above the limit of 10 set above

        session = MagicMock()
        mock_db.session = session

        @cloud_edition_billing_rate_limit_check("knowledge")
        def knowledge_request():
            return "knowledge_success"

        user_patch = patch("controllers.console.wraps.current_user", MockUser("test_user"))
        rate_limit_patch = patch(
            "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=rate_limit
        )

        # Act & Assert
        with flask_app.test_request_context(), user_patch, rate_limit_patch:
            with pytest.raises(Exception) as exc_info:
                knowledge_request()

            # Verify error
            assert exc_info.value.code == 403
            assert "rate limit" in str(exc_info.value.description)

            # Verify the session saw exactly one add and one commit
            session.add.assert_called_once()
            session.commit.assert_called_once()
-
-
class TestSystemSetup:
    """Tests for the ``setup_required`` decorator."""

    @patch("controllers.console.wraps.db")
    def test_should_allow_when_setup_complete(self, mock_db):
        """When the session query's ``first()`` yields an object, the view runs."""
        # Arrange
        first_row = mock_db.session.query.return_value.first
        first_row.return_value = MagicMock()  # a setup record is present

        @setup_required
        def admin_view():
            return "admin_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act
        with edition_patch:
            outcome = admin_view()

        # Assert
        assert outcome == "admin_success"

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.wraps.os.environ.get")
    def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
        """No setup record plus a truthy environment lookup raises NotInitValidateError."""
        # Arrange
        first_row = mock_db.session.query.return_value.first
        first_row.return_value = None  # no setup record
        # The patched getter answers every key with this value.
        mock_environ_get.return_value = "some_password"

        @setup_required
        def admin_view():
            return "admin_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act & Assert
        with edition_patch, pytest.raises(NotInitValidateError):
            admin_view()

    @patch("controllers.console.wraps.db")
    @patch("controllers.console.wraps.os.environ.get")
    def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
        """No setup record plus a ``None`` environment lookup raises NotSetupError."""
        # Arrange
        first_row = mock_db.session.query.return_value.first
        first_row.return_value = None  # no setup record
        mock_environ_get.return_value = None  # environment lookup finds nothing

        @setup_required
        def admin_view():
            return "admin_success"

        edition_patch = patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED")

        # Act & Assert
        with edition_patch, pytest.raises(NotSetupError):
            admin_view()
-
-
class TestEnterpriseLicense:
    """Tests for the ``enterprise_license_required`` decorator."""

    def test_should_allow_with_valid_license(self):
        """A license status of ACTIVE lets the wrapped view return normally."""
        # Arrange
        system_features = MagicMock(**{"license.status": LicenseStatus.ACTIVE})

        @enterprise_license_required
        def enterprise_feature():
            return "enterprise_success"

        features_patch = patch(
            "controllers.console.wraps.FeatureService.get_system_features", return_value=system_features
        )

        # Act
        with features_patch:
            outcome = enterprise_feature()

        # Assert
        assert outcome == "enterprise_success"

    @pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
    def test_should_reject_with_invalid_license(self, invalid_status):
        """INACTIVE, EXPIRED and LOST statuses each raise UnauthorizedAndForceLogout."""
        # Arrange
        system_features = MagicMock(**{"license.status": invalid_status})

        @enterprise_license_required
        def enterprise_feature():
            return "enterprise_success"

        features_patch = patch(
            "controllers.console.wraps.FeatureService.get_system_features", return_value=system_features
        )

        # Act & Assert
        with features_patch:
            with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
                enterprise_feature()
            assert "license is invalid" in str(exc_info.value)
|