Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.5.0
| from .create_document_index import handle | from .create_document_index import handle | ||||
| from .create_installed_app_when_app_created import handle | from .create_installed_app_when_app_created import handle | ||||
| from .create_site_record_when_app_created import handle | from .create_site_record_when_app_created import handle | ||||
| from .deduct_quota_when_message_created import handle | |||||
| from .delete_tool_parameters_cache_when_sync_draft_workflow import handle | from .delete_tool_parameters_cache_when_sync_draft_workflow import handle | ||||
| from .update_app_dataset_join_when_app_model_config_updated import handle | from .update_app_dataset_join_when_app_model_config_updated import handle | ||||
| from .update_app_dataset_join_when_app_published_workflow_updated import handle | from .update_app_dataset_join_when_app_published_workflow_updated import handle | ||||
| from .update_provider_last_used_at_when_message_created import handle | |||||
| # Consolidated handler replaces both deduct_quota_when_message_created and | |||||
| # update_provider_last_used_at_when_message_created | |||||
| from .update_provider_when_message_created import handle |
| from datetime import UTC, datetime | |||||
| from configs import dify_config | |||||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity | |||||
| from core.entities.provider_entities import QuotaUnit | |||||
| from core.plugin.entities.plugin import ModelProviderID | |||||
| from events.message_event import message_was_created | |||||
| from extensions.ext_database import db | |||||
| from models.provider import Provider, ProviderType | |||||
| @message_was_created.connect | |||||
| def handle(sender, **kwargs): | |||||
| message = sender | |||||
| application_generate_entity = kwargs.get("application_generate_entity") | |||||
| if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): | |||||
| return | |||||
| model_config = application_generate_entity.model_conf | |||||
| provider_model_bundle = model_config.provider_model_bundle | |||||
| provider_configuration = provider_model_bundle.configuration | |||||
| if provider_configuration.using_provider_type != ProviderType.SYSTEM: | |||||
| return | |||||
| system_configuration = provider_configuration.system_configuration | |||||
| if not system_configuration.current_quota_type: | |||||
| return | |||||
| quota_unit = None | |||||
| for quota_configuration in system_configuration.quota_configurations: | |||||
| if quota_configuration.quota_type == system_configuration.current_quota_type: | |||||
| quota_unit = quota_configuration.quota_unit | |||||
| if quota_configuration.quota_limit == -1: | |||||
| return | |||||
| break | |||||
| used_quota = None | |||||
| if quota_unit: | |||||
| if quota_unit == QuotaUnit.TOKENS: | |||||
| used_quota = message.message_tokens + message.answer_tokens | |||||
| elif quota_unit == QuotaUnit.CREDITS: | |||||
| used_quota = dify_config.get_model_credits(model_config.model) | |||||
| else: | |||||
| used_quota = 1 | |||||
| if used_quota is not None and system_configuration.current_quota_type is not None: | |||||
| db.session.query(Provider).filter( | |||||
| Provider.tenant_id == application_generate_entity.app_config.tenant_id, | |||||
| # TODO: Use provider name with prefix after the data migration. | |||||
| Provider.provider_name == ModelProviderID(model_config.provider).provider_name, | |||||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||||
| Provider.quota_type == system_configuration.current_quota_type.value, | |||||
| Provider.quota_limit > Provider.quota_used, | |||||
| ).update( | |||||
| { | |||||
| "quota_used": Provider.quota_used + used_quota, | |||||
| "last_used": datetime.now(tz=UTC).replace(tzinfo=None), | |||||
| } | |||||
| ) | |||||
| db.session.commit() |
| from datetime import UTC, datetime | |||||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity | |||||
| from events.message_event import message_was_created | |||||
| from extensions.ext_database import db | |||||
| from models.provider import Provider | |||||
| @message_was_created.connect | |||||
| def handle(sender, **kwargs): | |||||
| application_generate_entity = kwargs.get("application_generate_entity") | |||||
| if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): | |||||
| return | |||||
| db.session.query(Provider).filter( | |||||
| Provider.tenant_id == application_generate_entity.app_config.tenant_id, | |||||
| Provider.provider_name == application_generate_entity.model_conf.provider, | |||||
| ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)}) | |||||
| db.session.commit() |
| import logging | |||||
| import time as time_module | |||||
| from datetime import datetime | |||||
| from typing import Any, Optional | |||||
| from pydantic import BaseModel | |||||
| from sqlalchemy import update | |||||
| from configs import dify_config | |||||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity | |||||
| from core.entities.provider_entities import QuotaUnit, SystemConfiguration | |||||
| from core.plugin.entities.plugin import ModelProviderID | |||||
| from events.message_event import message_was_created | |||||
| from extensions.ext_database import db | |||||
| from libs import datetime_utils | |||||
| from models.model import Message | |||||
| from models.provider import Provider, ProviderType | |||||
| logger = logging.getLogger(__name__) | |||||
| class _ProviderUpdateFilters(BaseModel): | |||||
| """Filters for identifying Provider records to update.""" | |||||
| tenant_id: str | |||||
| provider_name: str | |||||
| provider_type: Optional[str] = None | |||||
| quota_type: Optional[str] = None | |||||
| class _ProviderUpdateAdditionalFilters(BaseModel): | |||||
| """Additional filters for Provider updates.""" | |||||
| quota_limit_check: bool = False | |||||
| class _ProviderUpdateValues(BaseModel): | |||||
| """Values to update in Provider records.""" | |||||
| last_used: Optional[datetime] = None | |||||
| quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression | |||||
| class _ProviderUpdateOperation(BaseModel): | |||||
| """A single Provider update operation.""" | |||||
| filters: _ProviderUpdateFilters | |||||
| values: _ProviderUpdateValues | |||||
| additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters() | |||||
| description: str = "unknown" | |||||
| @message_was_created.connect | |||||
| def handle(sender: Message, **kwargs): | |||||
| """ | |||||
| Consolidated handler for Provider updates when a message is created. | |||||
| This handler replaces both: | |||||
| - update_provider_last_used_at_when_message_created | |||||
| - deduct_quota_when_message_created | |||||
| By performing all Provider updates in a single transaction, we ensure | |||||
| consistency and efficiency when updating Provider records. | |||||
| """ | |||||
| message = sender | |||||
| application_generate_entity = kwargs.get("application_generate_entity") | |||||
| if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): | |||||
| return | |||||
| tenant_id = application_generate_entity.app_config.tenant_id | |||||
| provider_name = application_generate_entity.model_conf.provider | |||||
| current_time = datetime_utils.naive_utc_now() | |||||
| # Prepare updates for both scenarios | |||||
| updates_to_perform: list[_ProviderUpdateOperation] = [] | |||||
| # 1. Always update last_used for the provider | |||||
| basic_update = _ProviderUpdateOperation( | |||||
| filters=_ProviderUpdateFilters( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_name, | |||||
| ), | |||||
| values=_ProviderUpdateValues(last_used=current_time), | |||||
| description="basic_last_used_update", | |||||
| ) | |||||
| updates_to_perform.append(basic_update) | |||||
| # 2. Check if we need to deduct quota (system provider only) | |||||
| model_config = application_generate_entity.model_conf | |||||
| provider_model_bundle = model_config.provider_model_bundle | |||||
| provider_configuration = provider_model_bundle.configuration | |||||
| if ( | |||||
| provider_configuration.using_provider_type == ProviderType.SYSTEM | |||||
| and provider_configuration.system_configuration | |||||
| and provider_configuration.system_configuration.current_quota_type is not None | |||||
| ): | |||||
| system_configuration = provider_configuration.system_configuration | |||||
| # Calculate quota usage | |||||
| used_quota = _calculate_quota_usage( | |||||
| message=message, | |||||
| system_configuration=system_configuration, | |||||
| model_name=model_config.model, | |||||
| ) | |||||
| if used_quota is not None: | |||||
| quota_update = _ProviderUpdateOperation( | |||||
| filters=_ProviderUpdateFilters( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=ModelProviderID(model_config.provider).provider_name, | |||||
| provider_type=ProviderType.SYSTEM.value, | |||||
| quota_type=provider_configuration.system_configuration.current_quota_type.value, | |||||
| ), | |||||
| values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), | |||||
| additional_filters=_ProviderUpdateAdditionalFilters( | |||||
| quota_limit_check=True # Provider.quota_limit > Provider.quota_used | |||||
| ), | |||||
| description="quota_deduction_update", | |||||
| ) | |||||
| updates_to_perform.append(quota_update) | |||||
| # Execute all updates | |||||
| start_time = time_module.perf_counter() | |||||
| try: | |||||
| _execute_provider_updates(updates_to_perform) | |||||
| # Log successful completion with timing | |||||
| duration = time_module.perf_counter() - start_time | |||||
| logger.info( | |||||
| f"Provider updates completed successfully. " | |||||
| f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, " | |||||
| f"Tenant: {tenant_id}, Provider: {provider_name}" | |||||
| ) | |||||
| except Exception as e: | |||||
| # Log failure with timing and context | |||||
| duration = time_module.perf_counter() - start_time | |||||
| logger.exception( | |||||
| f"Provider updates failed after {duration:.3f}s. " | |||||
| f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, " | |||||
| f"Provider: {provider_name}" | |||||
| ) | |||||
| raise | |||||
| def _calculate_quota_usage( | |||||
| *, message: Message, system_configuration: SystemConfiguration, model_name: str | |||||
| ) -> Optional[int]: | |||||
| """Calculate quota usage based on message tokens and quota type.""" | |||||
| quota_unit = None | |||||
| for quota_configuration in system_configuration.quota_configurations: | |||||
| if quota_configuration.quota_type == system_configuration.current_quota_type: | |||||
| quota_unit = quota_configuration.quota_unit | |||||
| if quota_configuration.quota_limit == -1: | |||||
| return None | |||||
| break | |||||
| if quota_unit is None: | |||||
| return None | |||||
| try: | |||||
| if quota_unit == QuotaUnit.TOKENS: | |||||
| tokens = message.message_tokens + message.answer_tokens | |||||
| return tokens | |||||
| if quota_unit == QuotaUnit.CREDITS: | |||||
| tokens = dify_config.get_model_credits(model_name) | |||||
| return tokens | |||||
| elif quota_unit == QuotaUnit.TIMES: | |||||
| return 1 | |||||
| return None | |||||
| except Exception as e: | |||||
| logger.exception("Failed to calculate quota usage") | |||||
| return None | |||||
| def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]): | |||||
| """Execute all Provider updates in a single transaction.""" | |||||
| if not updates_to_perform: | |||||
| return | |||||
| # Use SQLAlchemy's context manager for transaction management | |||||
| # This automatically handles commit/rollback | |||||
| with db.session.begin(): | |||||
| # Use a single transaction for all updates | |||||
| for update_operation in updates_to_perform: | |||||
| filters = update_operation.filters | |||||
| values = update_operation.values | |||||
| additional_filters = update_operation.additional_filters | |||||
| description = update_operation.description | |||||
| # Build the where conditions | |||||
| where_conditions = [ | |||||
| Provider.tenant_id == filters.tenant_id, | |||||
| Provider.provider_name == filters.provider_name, | |||||
| ] | |||||
| # Add additional filters if specified | |||||
| if filters.provider_type is not None: | |||||
| where_conditions.append(Provider.provider_type == filters.provider_type) | |||||
| if filters.quota_type is not None: | |||||
| where_conditions.append(Provider.quota_type == filters.quota_type) | |||||
| if additional_filters.quota_limit_check: | |||||
| where_conditions.append(Provider.quota_limit > Provider.quota_used) | |||||
| # Prepare values dict for SQLAlchemy update | |||||
| update_values = {} | |||||
| if values.last_used is not None: | |||||
| update_values["last_used"] = values.last_used | |||||
| if values.quota_used is not None: | |||||
| update_values["quota_used"] = values.quota_used | |||||
| # Build and execute the update statement | |||||
| stmt = update(Provider).where(*where_conditions).values(**update_values) | |||||
| result = db.session.execute(stmt) | |||||
| rows_affected = result.rowcount | |||||
| logger.debug( | |||||
| f"Provider update ({description}): {rows_affected} rows affected. " | |||||
| f"Filters: {filters.model_dump()}, Values: {update_values}" | |||||
| ) | |||||
| # If no rows were affected for quota updates, log a warning | |||||
| if rows_affected == 0 and description == "quota_deduction_update": | |||||
| logger.warning( | |||||
| f"No Provider rows updated for quota deduction. " | |||||
| f"This may indicate quota limit exceeded or provider not found. " | |||||
| f"Filters: {filters.model_dump()}" | |||||
| ) | |||||
| logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates") |
| _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) | _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) | ||||
| query: Mapped[str] = db.Column(db.Text, nullable=False) | query: Mapped[str] = db.Column(db.Text, nullable=False) | ||||
| message = db.Column(db.JSON, nullable=False) | message = db.Column(db.JSON, nullable=False) | ||||
| message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||||
| message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||||
| message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | ||||
| message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) | message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) | ||||
| answer: Mapped[str] = db.Column(db.Text, nullable=False) | answer: Mapped[str] = db.Column(db.Text, nullable=False) | ||||
| answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||||
| answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||||
| answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | ||||
| answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) | answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) | ||||
| parent_message_id = db.Column(StringUUID, nullable=True) | parent_message_id = db.Column(StringUUID, nullable=True) |
| "types_setuptools>=80.9.0", | "types_setuptools>=80.9.0", | ||||
| "pandas-stubs~=2.2.3", | "pandas-stubs~=2.2.3", | ||||
| "scipy-stubs>=1.15.3.0", | "scipy-stubs>=1.15.3.0", | ||||
| "types-python-http-client>=3.3.7.20240910", | |||||
| ] | ] | ||||
| ############################################################ | ############################################################ |
| import threading | |||||
| from unittest.mock import Mock, patch | |||||
| from core.app.entities.app_invoke_entities import ChatAppGenerateEntity | |||||
| from core.entities.provider_entities import QuotaUnit | |||||
| from events.event_handlers.update_provider_when_message_created import ( | |||||
| handle, | |||||
| get_update_stats, | |||||
| ) | |||||
| from models.provider import ProviderType | |||||
| from sqlalchemy.exc import OperationalError | |||||
| class TestProviderUpdateDeadlockPrevention: | |||||
| """Test suite for deadlock prevention in Provider updates.""" | |||||
| def setup_method(self): | |||||
| """Setup test fixtures.""" | |||||
| self.mock_message = Mock() | |||||
| self.mock_message.answer_tokens = 100 | |||||
| self.mock_app_config = Mock() | |||||
| self.mock_app_config.tenant_id = "test-tenant-123" | |||||
| self.mock_model_conf = Mock() | |||||
| self.mock_model_conf.provider = "openai" | |||||
| self.mock_system_config = Mock() | |||||
| self.mock_system_config.current_quota_type = QuotaUnit.TOKENS | |||||
| self.mock_provider_config = Mock() | |||||
| self.mock_provider_config.using_provider_type = ProviderType.SYSTEM | |||||
| self.mock_provider_config.system_configuration = self.mock_system_config | |||||
| self.mock_provider_bundle = Mock() | |||||
| self.mock_provider_bundle.configuration = self.mock_provider_config | |||||
| self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle | |||||
| self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity) | |||||
| self.mock_generate_entity.app_config = self.mock_app_config | |||||
| self.mock_generate_entity.model_conf = self.mock_model_conf | |||||
| @patch("events.event_handlers.update_provider_when_message_created.db") | |||||
| def test_consolidated_handler_basic_functionality(self, mock_db): | |||||
| """Test that the consolidated handler performs both updates correctly.""" | |||||
| # Setup mock query chain | |||||
| mock_query = Mock() | |||||
| mock_db.session.query.return_value = mock_query | |||||
| mock_query.filter.return_value = mock_query | |||||
| mock_query.order_by.return_value = mock_query | |||||
| mock_query.update.return_value = 1 # 1 row affected | |||||
| # Call the handler | |||||
| handle(self.mock_message, application_generate_entity=self.mock_generate_entity) | |||||
| # Verify db.session.query was called | |||||
| assert mock_db.session.query.called | |||||
| # Verify commit was called | |||||
| mock_db.session.commit.assert_called_once() | |||||
| # Verify no rollback was called | |||||
| assert not mock_db.session.rollback.called | |||||
| @patch("events.event_handlers.update_provider_when_message_created.db") | |||||
| def test_deadlock_retry_mechanism(self, mock_db): | |||||
| """Test that deadlock errors trigger retry logic.""" | |||||
| # Setup mock to raise deadlock error on first attempt, succeed on second | |||||
| mock_query = Mock() | |||||
| mock_db.session.query.return_value = mock_query | |||||
| mock_query.filter.return_value = mock_query | |||||
| mock_query.order_by.return_value = mock_query | |||||
| mock_query.update.return_value = 1 | |||||
| # First call raises deadlock, second succeeds | |||||
| mock_db.session.commit.side_effect = [ | |||||
| OperationalError("deadlock detected", None, None), | |||||
| None, # Success on retry | |||||
| ] | |||||
| # Call the handler | |||||
| handle(self.mock_message, application_generate_entity=self.mock_generate_entity) | |||||
| # Verify commit was called twice (original + retry) | |||||
| assert mock_db.session.commit.call_count == 2 | |||||
| # Verify rollback was called once (after first failure) | |||||
| mock_db.session.rollback.assert_called_once() | |||||
| @patch("events.event_handlers.update_provider_when_message_created.db") | |||||
| @patch("events.event_handlers.update_provider_when_message_created.time.sleep") | |||||
| def test_exponential_backoff_timing(self, mock_sleep, mock_db): | |||||
| """Test that retry delays follow exponential backoff pattern.""" | |||||
| # Setup mock to fail twice, succeed on third attempt | |||||
| mock_query = Mock() | |||||
| mock_db.session.query.return_value = mock_query | |||||
| mock_query.filter.return_value = mock_query | |||||
| mock_query.order_by.return_value = mock_query | |||||
| mock_query.update.return_value = 1 | |||||
| mock_db.session.commit.side_effect = [ | |||||
| OperationalError("deadlock detected", None, None), | |||||
| OperationalError("deadlock detected", None, None), | |||||
| None, # Success on third attempt | |||||
| ] | |||||
| # Call the handler | |||||
| handle(self.mock_message, application_generate_entity=self.mock_generate_entity) | |||||
| # Verify sleep was called twice with increasing delays | |||||
| assert mock_sleep.call_count == 2 | |||||
| # First delay should be around 0.1s + jitter | |||||
| first_delay = mock_sleep.call_args_list[0][0][0] | |||||
| assert 0.1 <= first_delay <= 0.3 | |||||
| # Second delay should be around 0.2s + jitter | |||||
| second_delay = mock_sleep.call_args_list[1][0][0] | |||||
| assert 0.2 <= second_delay <= 0.4 | |||||
| def test_concurrent_handler_execution(self): | |||||
| """Test that multiple handlers can run concurrently without deadlock.""" | |||||
| results = [] | |||||
| errors = [] | |||||
| def run_handler(): | |||||
| try: | |||||
| with patch( | |||||
| "events.event_handlers.update_provider_when_message_created.db" | |||||
| ) as mock_db: | |||||
| mock_query = Mock() | |||||
| mock_db.session.query.return_value = mock_query | |||||
| mock_query.filter.return_value = mock_query | |||||
| mock_query.order_by.return_value = mock_query | |||||
| mock_query.update.return_value = 1 | |||||
| handle( | |||||
| self.mock_message, | |||||
| application_generate_entity=self.mock_generate_entity, | |||||
| ) | |||||
| results.append("success") | |||||
| except Exception as e: | |||||
| errors.append(str(e)) | |||||
| # Run multiple handlers concurrently | |||||
| threads = [] | |||||
| for _ in range(5): | |||||
| thread = threading.Thread(target=run_handler) | |||||
| threads.append(thread) | |||||
| thread.start() | |||||
| # Wait for all threads to complete | |||||
| for thread in threads: | |||||
| thread.join(timeout=5) | |||||
| # Verify all handlers completed successfully | |||||
| assert len(results) == 5 | |||||
| assert len(errors) == 0 | |||||
| def test_performance_stats_tracking(self): | |||||
| """Test that performance statistics are tracked correctly.""" | |||||
| # Reset stats | |||||
| stats = get_update_stats() | |||||
| initial_total = stats["total_updates"] | |||||
| with patch( | |||||
| "events.event_handlers.update_provider_when_message_created.db" | |||||
| ) as mock_db: | |||||
| mock_query = Mock() | |||||
| mock_db.session.query.return_value = mock_query | |||||
| mock_query.filter.return_value = mock_query | |||||
| mock_query.order_by.return_value = mock_query | |||||
| mock_query.update.return_value = 1 | |||||
| # Call handler | |||||
| handle( | |||||
| self.mock_message, application_generate_entity=self.mock_generate_entity | |||||
| ) | |||||
| # Check that stats were updated | |||||
| updated_stats = get_update_stats() | |||||
| assert updated_stats["total_updates"] == initial_total + 1 | |||||
| assert updated_stats["successful_updates"] >= initial_total + 1 | |||||
| def test_non_chat_entity_ignored(self): | |||||
| """Test that non-chat entities are ignored by the handler.""" | |||||
| # Create a non-chat entity | |||||
| mock_non_chat_entity = Mock() | |||||
| mock_non_chat_entity.__class__.__name__ = "NonChatEntity" | |||||
| with patch( | |||||
| "events.event_handlers.update_provider_when_message_created.db" | |||||
| ) as mock_db: | |||||
| # Call handler with non-chat entity | |||||
| handle(self.mock_message, application_generate_entity=mock_non_chat_entity) | |||||
| # Verify no database operations were performed | |||||
| assert not mock_db.session.query.called | |||||
| assert not mock_db.session.commit.called | |||||
| @patch("events.event_handlers.update_provider_when_message_created.db") | |||||
| def test_quota_calculation_tokens(self, mock_db): | |||||
| """Test quota calculation for token-based quotas.""" | |||||
| # Setup token-based quota | |||||
| self.mock_system_config.current_quota_type = QuotaUnit.TOKENS | |||||
| self.mock_message.answer_tokens = 150 | |||||
| mock_query = Mock() | |||||
| mock_db.session.query.return_value = mock_query | |||||
| mock_query.filter.return_value = mock_query | |||||
| mock_query.order_by.return_value = mock_query | |||||
| mock_query.update.return_value = 1 | |||||
| # Call handler | |||||
| handle(self.mock_message, application_generate_entity=self.mock_generate_entity) | |||||
| # Verify update was called with token count | |||||
| update_calls = mock_query.update.call_args_list | |||||
| # Should have at least one call with quota_used update | |||||
| quota_update_found = False | |||||
| for call in update_calls: | |||||
| values = call[0][0] # First argument to update() | |||||
| if "quota_used" in values: | |||||
| quota_update_found = True | |||||
| break | |||||
| assert quota_update_found | |||||
| @patch("events.event_handlers.update_provider_when_message_created.db") | |||||
| def test_quota_calculation_times(self, mock_db): | |||||
| """Test quota calculation for times-based quotas.""" | |||||
| # Setup times-based quota | |||||
| self.mock_system_config.current_quota_type = QuotaUnit.TIMES | |||||
| mock_query = Mock() | |||||
| mock_db.session.query.return_value = mock_query | |||||
| mock_query.filter.return_value = mock_query | |||||
| mock_query.order_by.return_value = mock_query | |||||
| mock_query.update.return_value = 1 | |||||
| # Call handler | |||||
| handle(self.mock_message, application_generate_entity=self.mock_generate_entity) | |||||
| # Verify update was called | |||||
| assert mock_query.update.called | |||||
| assert mock_db.session.commit.called |