| @@ -0,0 +1,474 @@ | |||
| from unittest.mock import MagicMock, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from models.account import TenantAccountJoin, TenantAccountRole | |||
| from models.model import Account, Tenant | |||
| from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting | |||
| from services.model_load_balancing_service import ModelLoadBalancingService | |||
| class TestModelLoadBalancingService: | |||
| """Integration tests for ModelLoadBalancingService using testcontainers.""" | |||
| @pytest.fixture | |||
| def mock_external_service_dependencies(self): | |||
| """Mock setup for external service dependencies.""" | |||
| with ( | |||
| patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager, | |||
| patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager, | |||
| patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory, | |||
| patch("services.model_load_balancing_service.encrypter") as mock_encrypter, | |||
| ): | |||
| # Setup default mock returns | |||
| mock_provider_manager_instance = mock_provider_manager.return_value | |||
| # Mock provider configuration | |||
| mock_provider_config = MagicMock() | |||
| mock_provider_config.provider.provider = "openai" | |||
| mock_provider_config.custom_configuration.provider = None | |||
| # Mock provider model setting | |||
| mock_provider_model_setting = MagicMock() | |||
| mock_provider_model_setting.load_balancing_enabled = False | |||
| mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting | |||
| # Mock provider configurations dict | |||
| mock_provider_configs = {"openai": mock_provider_config} | |||
| mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs | |||
| # Mock LBModelManager | |||
| mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) | |||
| # Mock ModelProviderFactory | |||
| mock_model_provider_factory_instance = mock_model_provider_factory.return_value | |||
| # Mock credential schemas | |||
| mock_credential_schema = MagicMock() | |||
| mock_credential_schema.credential_form_schemas = [] | |||
| # Mock provider configuration methods | |||
| mock_provider_config.extract_secret_variables.return_value = [] | |||
| mock_provider_config.obfuscated_credentials.return_value = {} | |||
| mock_provider_config._get_credential_schema.return_value = mock_credential_schema | |||
| yield { | |||
| "provider_manager": mock_provider_manager, | |||
| "lb_model_manager": mock_lb_model_manager, | |||
| "model_provider_factory": mock_model_provider_factory, | |||
| "encrypter": mock_encrypter, | |||
| "provider_config": mock_provider_config, | |||
| "provider_model_setting": mock_provider_model_setting, | |||
| "credential_schema": mock_credential_schema, | |||
| } | |||
| def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Helper method to create a test account and tenant for testing. | |||
| Args: | |||
| db_session_with_containers: Database session from testcontainers infrastructure | |||
| mock_external_service_dependencies: Mock dependencies | |||
| Returns: | |||
| tuple: (account, tenant) - Created account and tenant instances | |||
| """ | |||
| fake = Faker() | |||
| # Create account | |||
| account = Account( | |||
| email=fake.email(), | |||
| name=fake.name(), | |||
| interface_language="en-US", | |||
| status="active", | |||
| ) | |||
| from extensions.ext_database import db | |||
| db.session.add(account) | |||
| db.session.commit() | |||
| # Create tenant for the account | |||
| tenant = Tenant( | |||
| name=fake.company(), | |||
| status="normal", | |||
| ) | |||
| db.session.add(tenant) | |||
| db.session.commit() | |||
| # Create tenant-account join | |||
| join = TenantAccountJoin( | |||
| tenant_id=tenant.id, | |||
| account_id=account.id, | |||
| role=TenantAccountRole.OWNER.value, | |||
| current=True, | |||
| ) | |||
| db.session.add(join) | |||
| db.session.commit() | |||
| # Set current tenant for account | |||
| account.current_tenant = tenant | |||
| return account, tenant | |||
| def _create_test_provider_and_setting( | |||
| self, db_session_with_containers, tenant_id, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Helper method to create a test provider and provider model setting. | |||
| Args: | |||
| db_session_with_containers: Database session from testcontainers infrastructure | |||
| tenant_id: Tenant ID for the provider | |||
| mock_external_service_dependencies: Mock dependencies | |||
| Returns: | |||
| tuple: (provider, provider_model_setting) - Created provider and setting instances | |||
| """ | |||
| fake = Faker() | |||
| from extensions.ext_database import db | |||
| # Create provider | |||
| provider = Provider( | |||
| tenant_id=tenant_id, | |||
| provider_name="openai", | |||
| provider_type="custom", | |||
| is_valid=True, | |||
| ) | |||
| db.session.add(provider) | |||
| db.session.commit() | |||
| # Create provider model setting | |||
| provider_model_setting = ProviderModelSetting( | |||
| tenant_id=tenant_id, | |||
| provider_name="openai", | |||
| model_name="gpt-3.5-turbo", | |||
| model_type="text-generation", # Use the origin model type that matches the query | |||
| enabled=True, | |||
| load_balancing_enabled=False, | |||
| ) | |||
| db.session.add(provider_model_setting) | |||
| db.session.commit() | |||
| return provider, provider_model_setting | |||
| def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Test successful model load balancing enablement. | |||
| This test verifies: | |||
| - Proper provider configuration retrieval | |||
| - Successful enablement of model load balancing | |||
| - Correct method calls to provider configuration | |||
| """ | |||
| # Arrange: Create test data | |||
| fake = Faker() | |||
| account, tenant = self._create_test_account_and_tenant( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| provider, provider_model_setting = self._create_test_provider_and_setting( | |||
| db_session_with_containers, tenant.id, mock_external_service_dependencies | |||
| ) | |||
| # Setup mocks for enable method | |||
| mock_provider_config = mock_external_service_dependencies["provider_config"] | |||
| mock_provider_config.enable_model_load_balancing = MagicMock() | |||
| # Act: Execute the method under test | |||
| service = ModelLoadBalancingService() | |||
| service.enable_model_load_balancing( | |||
| tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" | |||
| ) | |||
| # Assert: Verify the expected outcomes | |||
| mock_provider_config.enable_model_load_balancing.assert_called_once() | |||
| call_args = mock_provider_config.enable_model_load_balancing.call_args | |||
| assert call_args.kwargs["model"] == "gpt-3.5-turbo" | |||
| assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value | |||
| # Verify database state | |||
| from extensions.ext_database import db | |||
| db.session.refresh(provider) | |||
| db.session.refresh(provider_model_setting) | |||
| assert provider.id is not None | |||
| assert provider_model_setting.id is not None | |||
| def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Test successful model load balancing disablement. | |||
| This test verifies: | |||
| - Proper provider configuration retrieval | |||
| - Successful disablement of model load balancing | |||
| - Correct method calls to provider configuration | |||
| """ | |||
| # Arrange: Create test data | |||
| fake = Faker() | |||
| account, tenant = self._create_test_account_and_tenant( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| provider, provider_model_setting = self._create_test_provider_and_setting( | |||
| db_session_with_containers, tenant.id, mock_external_service_dependencies | |||
| ) | |||
| # Setup mocks for disable method | |||
| mock_provider_config = mock_external_service_dependencies["provider_config"] | |||
| mock_provider_config.disable_model_load_balancing = MagicMock() | |||
| # Act: Execute the method under test | |||
| service = ModelLoadBalancingService() | |||
| service.disable_model_load_balancing( | |||
| tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" | |||
| ) | |||
| # Assert: Verify the expected outcomes | |||
| mock_provider_config.disable_model_load_balancing.assert_called_once() | |||
| call_args = mock_provider_config.disable_model_load_balancing.call_args | |||
| assert call_args.kwargs["model"] == "gpt-3.5-turbo" | |||
| assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value | |||
| # Verify database state | |||
| from extensions.ext_database import db | |||
| db.session.refresh(provider) | |||
| db.session.refresh(provider_model_setting) | |||
| assert provider.id is not None | |||
| assert provider_model_setting.id is not None | |||
| def test_enable_model_load_balancing_provider_not_found( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test error handling when provider does not exist. | |||
| This test verifies: | |||
| - Proper error handling for non-existent provider | |||
| - Correct exception type and message | |||
| - No database state changes | |||
| """ | |||
| # Arrange: Create test data | |||
| fake = Faker() | |||
| account, tenant = self._create_test_account_and_tenant( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Setup mocks to return empty provider configurations | |||
| mock_provider_manager = mock_external_service_dependencies["provider_manager"] | |||
| mock_provider_manager_instance = mock_provider_manager.return_value | |||
| mock_provider_manager_instance.get_configurations.return_value = {} | |||
| # Act & Assert: Verify proper error handling | |||
| service = ModelLoadBalancingService() | |||
| with pytest.raises(ValueError) as exc_info: | |||
| service.enable_model_load_balancing( | |||
| tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" | |||
| ) | |||
| # Verify correct error message | |||
| assert "Provider nonexistent_provider does not exist." in str(exc_info.value) | |||
| # Verify no database state changes occurred | |||
| from extensions.ext_database import db | |||
| db.session.rollback() | |||
| def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| Test successful retrieval of load balancing configurations. | |||
| This test verifies: | |||
| - Proper provider configuration retrieval | |||
| - Successful database query for load balancing configs | |||
| - Correct return format and data structure | |||
| """ | |||
| # Arrange: Create test data | |||
| fake = Faker() | |||
| account, tenant = self._create_test_account_and_tenant( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| provider, provider_model_setting = self._create_test_provider_and_setting( | |||
| db_session_with_containers, tenant.id, mock_external_service_dependencies | |||
| ) | |||
| # Create load balancing config | |||
| from extensions.ext_database import db | |||
| load_balancing_config = LoadBalancingModelConfig( | |||
| tenant_id=tenant.id, | |||
| provider_name="openai", | |||
| model_name="gpt-3.5-turbo", | |||
| model_type="text-generation", # Use the origin model type that matches the query | |||
| name="config1", | |||
| encrypted_config='{"api_key": "test_key"}', | |||
| enabled=True, | |||
| ) | |||
| db.session.add(load_balancing_config) | |||
| db.session.commit() | |||
| # Verify the config was created | |||
| db.session.refresh(load_balancing_config) | |||
| assert load_balancing_config.id is not None | |||
| # Setup mocks for get_load_balancing_configs method | |||
| mock_provider_config = mock_external_service_dependencies["provider_config"] | |||
| mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] | |||
| mock_provider_model_setting.load_balancing_enabled = True | |||
| # Mock credential schema methods | |||
| mock_credential_schema = mock_external_service_dependencies["credential_schema"] | |||
| mock_credential_schema.credential_form_schemas = [] | |||
| # Mock encrypter | |||
| mock_encrypter = mock_external_service_dependencies["encrypter"] | |||
| mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") | |||
| # Mock _get_credential_schema method | |||
| mock_provider_config._get_credential_schema.return_value = mock_credential_schema | |||
| # Mock extract_secret_variables method | |||
| mock_provider_config.extract_secret_variables.return_value = [] | |||
| # Mock obfuscated_credentials method | |||
| mock_provider_config.obfuscated_credentials.return_value = {} | |||
| # Mock LBModelManager.get_config_in_cooldown_and_ttl | |||
| mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"] | |||
| mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) | |||
| # Act: Execute the method under test | |||
| service = ModelLoadBalancingService() | |||
| is_enabled, configs = service.get_load_balancing_configs( | |||
| tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" | |||
| ) | |||
| # Assert: Verify the expected outcomes | |||
| assert is_enabled is True | |||
| assert len(configs) == 1 | |||
| assert configs[0]["id"] == load_balancing_config.id | |||
| assert configs[0]["name"] == "config1" | |||
| assert configs[0]["enabled"] is True | |||
| assert configs[0]["in_cooldown"] is False | |||
| assert configs[0]["ttl"] == 0 | |||
| # Verify database state | |||
| db.session.refresh(load_balancing_config) | |||
| assert load_balancing_config.id is not None | |||
| def test_get_load_balancing_configs_provider_not_found( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test error handling when provider does not exist in get_load_balancing_configs. | |||
| This test verifies: | |||
| - Proper error handling for non-existent provider | |||
| - Correct exception type and message | |||
| - No database state changes | |||
| """ | |||
| # Arrange: Create test data | |||
| fake = Faker() | |||
| account, tenant = self._create_test_account_and_tenant( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| # Setup mocks to return empty provider configurations | |||
| mock_provider_manager = mock_external_service_dependencies["provider_manager"] | |||
| mock_provider_manager_instance = mock_provider_manager.return_value | |||
| mock_provider_manager_instance.get_configurations.return_value = {} | |||
| # Act & Assert: Verify proper error handling | |||
| service = ModelLoadBalancingService() | |||
| with pytest.raises(ValueError) as exc_info: | |||
| service.get_load_balancing_configs( | |||
| tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" | |||
| ) | |||
| # Verify correct error message | |||
| assert "Provider nonexistent_provider does not exist." in str(exc_info.value) | |||
| # Verify no database state changes occurred | |||
| from extensions.ext_database import db | |||
| db.session.rollback() | |||
| def test_get_load_balancing_configs_with_inherit_config( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| ): | |||
| """ | |||
| Test load balancing configs retrieval with inherit configuration. | |||
| This test verifies: | |||
| - Proper handling of inherit configuration | |||
| - Correct ordering of configurations | |||
| - Inherit config initialization when needed | |||
| """ | |||
| # Arrange: Create test data | |||
| fake = Faker() | |||
| account, tenant = self._create_test_account_and_tenant( | |||
| db_session_with_containers, mock_external_service_dependencies | |||
| ) | |||
| provider, provider_model_setting = self._create_test_provider_and_setting( | |||
| db_session_with_containers, tenant.id, mock_external_service_dependencies | |||
| ) | |||
| # Create load balancing config | |||
| from extensions.ext_database import db | |||
| load_balancing_config = LoadBalancingModelConfig( | |||
| tenant_id=tenant.id, | |||
| provider_name="openai", | |||
| model_name="gpt-3.5-turbo", | |||
| model_type="text-generation", # Use the origin model type that matches the query | |||
| name="config1", | |||
| encrypted_config='{"api_key": "test_key"}', | |||
| enabled=True, | |||
| ) | |||
| db.session.add(load_balancing_config) | |||
| db.session.commit() | |||
| # Setup mocks for inherit config scenario | |||
| mock_provider_config = mock_external_service_dependencies["provider_config"] | |||
| mock_provider_config.custom_configuration.provider = MagicMock() # Enable custom config | |||
| mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] | |||
| mock_provider_model_setting.load_balancing_enabled = True | |||
| # Mock credential schema methods | |||
| mock_credential_schema = mock_external_service_dependencies["credential_schema"] | |||
| mock_credential_schema.credential_form_schemas = [] | |||
| # Mock encrypter | |||
| mock_encrypter = mock_external_service_dependencies["encrypter"] | |||
| mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") | |||
| # Act: Execute the method under test | |||
| service = ModelLoadBalancingService() | |||
| is_enabled, configs = service.get_load_balancing_configs( | |||
| tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" | |||
| ) | |||
| # Assert: Verify the expected outcomes | |||
| assert is_enabled is True | |||
| assert len(configs) == 2 # inherit config + existing config | |||
| # First config should be inherit config | |||
| assert configs[0]["name"] == "__inherit__" | |||
| assert configs[0]["enabled"] is True | |||
| # Second config should be the existing config | |||
| assert configs[1]["id"] == load_balancing_config.id | |||
| assert configs[1]["name"] == "config1" | |||
| # Verify database state | |||
| db.session.refresh(load_balancing_config) | |||
| assert load_balancing_config.id is not None | |||
| # Verify inherit config was created in database | |||
| inherit_configs = ( | |||
| db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all() | |||
| ) | |||
| assert len(inherit_configs) == 1 | |||