Signed-off-by: -LAN- <laipz8200@outlook.com>
@@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
from .create_document_index import handle
from .create_installed_app_when_app_created import handle
from .create_site_record_when_app_created import handle
from .deduct_quota_when_message_created import handle
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle
from .update_provider_last_used_at_when_message_created import handle
# Consolidated handler replaces both deduct_quota_when_message_created and
# update_provider_last_used_at_when_message_created
from .update_provider_when_message_created import handle
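
Aside (not part of the diff): a minimal sketch of why importing the new module is enough to register the consolidated handler. It assumes message_was_created is an ordinary blinker signal as in events/message_event.py (the signal name string below is illustrative): @message_was_created.connect registers the receiver at import time, and send() fans the event out to every connected receiver.

from blinker import signal

message_was_created = signal("message-was-created")  # stand-in for events.message_event


@message_was_created.connect
def handle(sender, **kwargs):
    # In the real handler, `sender` is the Message and kwargs carry
    # application_generate_entity; this sketch only shows the dispatch.
    print(f"provider update triggered for {sender!r}")


# Normally emitted by the message pipeline after a message is persisted:
message_was_created.send("msg-123", application_generate_entity=None)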

@@ -1,65 +0,0 @@
from datetime import UTC, datetime

from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider, ProviderType


@message_was_created.connect
def handle(sender, **kwargs):
    message = sender
    application_generate_entity = kwargs.get("application_generate_entity")

    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
        return

    model_config = application_generate_entity.model_conf
    provider_model_bundle = model_config.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = provider_configuration.system_configuration

    if not system_configuration.current_quota_type:
        return

    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit

            if quota_configuration.quota_limit == -1:
                return

            break

    used_quota = None
    if quota_unit:
        if quota_unit == QuotaUnit.TOKENS:
            used_quota = message.message_tokens + message.answer_tokens
        elif quota_unit == QuotaUnit.CREDITS:
            used_quota = dify_config.get_model_credits(model_config.model)
        else:
            used_quota = 1

    if used_quota is not None and system_configuration.current_quota_type is not None:
        db.session.query(Provider).filter(
            Provider.tenant_id == application_generate_entity.app_config.tenant_id,
            # TODO: Use provider name with prefix after the data migration.
            Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value,
            Provider.quota_type == system_configuration.current_quota_type.value,
            Provider.quota_limit > Provider.quota_used,
        ).update(
            {
                "quota_used": Provider.quota_used + used_quota,
                "last_used": datetime.now(tz=UTC).replace(tzinfo=None),
            }
        )
        db.session.commit()

@@ -1,20 +0,0 @@
from datetime import UTC, datetime

from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider


@message_was_created.connect
def handle(sender, **kwargs):
    application_generate_entity = kwargs.get("application_generate_entity")

    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
        return

    db.session.query(Provider).filter(
        Provider.tenant_id == application_generate_entity.app_config.tenant_id,
        Provider.provider_name == application_generate_entity.model_conf.provider,
    ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
    db.session.commit()

@@ -0,0 +1,233 @@
import logging
import time as time_module
from datetime import datetime
from typing import Any, Optional

from pydantic import BaseModel
from sqlalchemy import update

from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType

logger = logging.getLogger(__name__)


class _ProviderUpdateFilters(BaseModel):
    """Filters for identifying Provider records to update."""

    tenant_id: str
    provider_name: str
    provider_type: Optional[str] = None
    quota_type: Optional[str] = None


class _ProviderUpdateAdditionalFilters(BaseModel):
    """Additional filters for Provider updates."""

    quota_limit_check: bool = False


class _ProviderUpdateValues(BaseModel):
    """Values to update in Provider records."""

    last_used: Optional[datetime] = None
    quota_used: Optional[Any] = None  # Can be Provider.quota_used + int expression


class _ProviderUpdateOperation(BaseModel):
    """A single Provider update operation."""

    filters: _ProviderUpdateFilters
    values: _ProviderUpdateValues
    additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
    description: str = "unknown"


@message_was_created.connect
def handle(sender: Message, **kwargs):
    """
    Consolidated handler for Provider updates when a message is created.

    This handler replaces both:
    - update_provider_last_used_at_when_message_created
    - deduct_quota_when_message_created

    By performing all Provider updates in a single transaction, we ensure
    consistency and efficiency when updating Provider records.
    """
    message = sender
    application_generate_entity = kwargs.get("application_generate_entity")

    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
        return

    tenant_id = application_generate_entity.app_config.tenant_id
    provider_name = application_generate_entity.model_conf.provider
    current_time = datetime_utils.naive_utc_now()

    # Prepare updates for both scenarios
    updates_to_perform: list[_ProviderUpdateOperation] = []

    # 1. Always update last_used for the provider
    basic_update = _ProviderUpdateOperation(
        filters=_ProviderUpdateFilters(
            tenant_id=tenant_id,
            provider_name=provider_name,
        ),
        values=_ProviderUpdateValues(last_used=current_time),
        description="basic_last_used_update",
    )
    updates_to_perform.append(basic_update)

    # 2. Check if we need to deduct quota (system provider only)
    model_config = application_generate_entity.model_conf
    provider_model_bundle = model_config.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if (
        provider_configuration.using_provider_type == ProviderType.SYSTEM
        and provider_configuration.system_configuration
        and provider_configuration.system_configuration.current_quota_type is not None
    ):
        system_configuration = provider_configuration.system_configuration

        # Calculate quota usage
        used_quota = _calculate_quota_usage(
            message=message,
            system_configuration=system_configuration,
            model_name=model_config.model,
        )

        if used_quota is not None:
            quota_update = _ProviderUpdateOperation(
                filters=_ProviderUpdateFilters(
                    tenant_id=tenant_id,
                    provider_name=ModelProviderID(model_config.provider).provider_name,
                    provider_type=ProviderType.SYSTEM.value,
                    quota_type=provider_configuration.system_configuration.current_quota_type.value,
                ),
                values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
                additional_filters=_ProviderUpdateAdditionalFilters(
                    quota_limit_check=True  # Provider.quota_limit > Provider.quota_used
                ),
                description="quota_deduction_update",
            )
            updates_to_perform.append(quota_update)

    # Execute all updates
    start_time = time_module.perf_counter()
    try:
        _execute_provider_updates(updates_to_perform)

        # Log successful completion with timing
        duration = time_module.perf_counter() - start_time
        logger.info(
            f"Provider updates completed successfully. "
            f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
            f"Tenant: {tenant_id}, Provider: {provider_name}"
        )
    except Exception:
        # Log failure with timing and context
        duration = time_module.perf_counter() - start_time
        logger.exception(
            f"Provider updates failed after {duration:.3f}s. "
            f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
            f"Provider: {provider_name}"
        )
        raise


def _calculate_quota_usage(
    *, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> Optional[int]:
    """Calculate quota usage based on message tokens and quota type."""
    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit
            if quota_configuration.quota_limit == -1:
                return None
            break

    if quota_unit is None:
        return None

    try:
        if quota_unit == QuotaUnit.TOKENS:
            tokens = message.message_tokens + message.answer_tokens
            return tokens
        if quota_unit == QuotaUnit.CREDITS:
            tokens = dify_config.get_model_credits(model_name)
            return tokens
        elif quota_unit == QuotaUnit.TIMES:
            return 1
        return None
    except Exception:
        logger.exception("Failed to calculate quota usage")
        return None


def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
    """Execute all Provider updates in a single transaction."""
    if not updates_to_perform:
        return

    # Use SQLAlchemy's context manager for transaction management
    # This automatically handles commit/rollback
    with db.session.begin():
        # Use a single transaction for all updates
        for update_operation in updates_to_perform:
            filters = update_operation.filters
            values = update_operation.values
            additional_filters = update_operation.additional_filters
            description = update_operation.description

            # Build the where conditions
            where_conditions = [
                Provider.tenant_id == filters.tenant_id,
                Provider.provider_name == filters.provider_name,
            ]

            # Add additional filters if specified
            if filters.provider_type is not None:
                where_conditions.append(Provider.provider_type == filters.provider_type)
            if filters.quota_type is not None:
                where_conditions.append(Provider.quota_type == filters.quota_type)
            if additional_filters.quota_limit_check:
                where_conditions.append(Provider.quota_limit > Provider.quota_used)

            # Prepare values dict for SQLAlchemy update
            update_values = {}
            if values.last_used is not None:
                update_values["last_used"] = values.last_used
            if values.quota_used is not None:
                update_values["quota_used"] = values.quota_used

            # Build and execute the update statement
            stmt = update(Provider).where(*where_conditions).values(**update_values)
            result = db.session.execute(stmt)
            rows_affected = result.rowcount

            logger.debug(
                f"Provider update ({description}): {rows_affected} rows affected. "
                f"Filters: {filters.model_dump()}, Values: {update_values}"
            )

            # If no rows were affected for quota updates, log a warning
            if rows_affected == 0 and description == "quota_deduction_update":
                logger.warning(
                    f"No Provider rows updated for quota deduction. "
                    f"This may indicate quota limit exceeded or provider not found. "
                    f"Filters: {filters.model_dump()}"
                )

    logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")

@@ -914,11 +914,11 @@ class Message(Base):
    _inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
    query: Mapped[str] = db.Column(db.Text, nullable=False)
    message = db.Column(db.JSON, nullable=False)
    message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
    message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
    message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
    message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
    answer: Mapped[str] = db.Column(db.Text, nullable=False)
    answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
    answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
    answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
    answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
    parent_message_id = db.Column(StringUUID, nullable=True)
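
Aside (not part of the diff): why the Mapped[int] annotations matter. The consolidated handler sums message.message_tokens + message.answer_tokens; without an annotation, a type checker sees the legacy db.Column attribute as untyped and cannot verify that sum. A minimal SQLAlchemy 2.0-style sketch (the MessageSketch model is hypothetical and uses mapped_column rather than the legacy db.Column of the real model):

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class MessageSketch(Base):  # hypothetical stand-in for models.model.Message
    __tablename__ = "messages_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    message_tokens: Mapped[int] = mapped_column(default=0)
    answer_tokens: Mapped[int] = mapped_column(default=0)


def total_tokens(m: MessageSketch) -> int:
    # With Mapped[int], the checker knows both attributes are ints on instances.
    return m.message_tokens + m.answer_tokens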

@@ -155,6 +155,7 @@ dev = [
    "types_setuptools>=80.9.0",
    "pandas-stubs~=2.2.3",
    "scipy-stubs>=1.15.3.0",
    "types-python-http-client>=3.3.7.20240910",
]

############################################################

@@ -0,0 +1,248 @@
import threading
from unittest.mock import Mock, patch

from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.event_handlers.update_provider_when_message_created import (
    handle,
    get_update_stats,
)
from models.provider import ProviderType
from sqlalchemy.exc import OperationalError


class TestProviderUpdateDeadlockPrevention:
    """Test suite for deadlock prevention in Provider updates."""

    def setup_method(self):
        """Setup test fixtures."""
        self.mock_message = Mock()
        self.mock_message.answer_tokens = 100

        self.mock_app_config = Mock()
        self.mock_app_config.tenant_id = "test-tenant-123"

        self.mock_model_conf = Mock()
        self.mock_model_conf.provider = "openai"

        self.mock_system_config = Mock()
        self.mock_system_config.current_quota_type = QuotaUnit.TOKENS

        self.mock_provider_config = Mock()
        self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
        self.mock_provider_config.system_configuration = self.mock_system_config

        self.mock_provider_bundle = Mock()
        self.mock_provider_bundle.configuration = self.mock_provider_config

        self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle

        self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
        self.mock_generate_entity.app_config = self.mock_app_config
        self.mock_generate_entity.model_conf = self.mock_model_conf

    @patch("events.event_handlers.update_provider_when_message_created.db")
    def test_consolidated_handler_basic_functionality(self, mock_db):
        """Test that the consolidated handler performs both updates correctly."""
        # Setup mock query chain
        mock_query = Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.update.return_value = 1  # 1 row affected

        # Call the handler
        handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Verify db.session.query was called
        assert mock_db.session.query.called
        # Verify commit was called
        mock_db.session.commit.assert_called_once()
        # Verify no rollback was called
        assert not mock_db.session.rollback.called

    @patch("events.event_handlers.update_provider_when_message_created.db")
    def test_deadlock_retry_mechanism(self, mock_db):
        """Test that deadlock errors trigger retry logic."""
        # Setup mock to raise deadlock error on first attempt, succeed on second
        mock_query = Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.update.return_value = 1

        # First call raises deadlock, second succeeds
        mock_db.session.commit.side_effect = [
            OperationalError("deadlock detected", None, None),
            None,  # Success on retry
        ]

        # Call the handler
        handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Verify commit was called twice (original + retry)
        assert mock_db.session.commit.call_count == 2
        # Verify rollback was called once (after first failure)
        mock_db.session.rollback.assert_called_once()

    @patch("events.event_handlers.update_provider_when_message_created.db")
    @patch("events.event_handlers.update_provider_when_message_created.time.sleep")
    def test_exponential_backoff_timing(self, mock_sleep, mock_db):
        """Test that retry delays follow exponential backoff pattern."""
        # Setup mock to fail twice, succeed on third attempt
        mock_query = Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.update.return_value = 1

        mock_db.session.commit.side_effect = [
            OperationalError("deadlock detected", None, None),
            OperationalError("deadlock detected", None, None),
            None,  # Success on third attempt
        ]

        # Call the handler
        handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Verify sleep was called twice with increasing delays
        assert mock_sleep.call_count == 2
        # First delay should be around 0.1s + jitter
        first_delay = mock_sleep.call_args_list[0][0][0]
        assert 0.1 <= first_delay <= 0.3
        # Second delay should be around 0.2s + jitter
        second_delay = mock_sleep.call_args_list[1][0][0]
        assert 0.2 <= second_delay <= 0.4

    def test_concurrent_handler_execution(self):
        """Test that multiple handlers can run concurrently without deadlock."""
        results = []
        errors = []

        def run_handler():
            try:
                with patch("events.event_handlers.update_provider_when_message_created.db") as mock_db:
                    mock_query = Mock()
                    mock_db.session.query.return_value = mock_query
                    mock_query.filter.return_value = mock_query
                    mock_query.order_by.return_value = mock_query
                    mock_query.update.return_value = 1

                    handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
                    results.append("success")
            except Exception as e:
                errors.append(str(e))

        # Run multiple handlers concurrently
        threads = []
        for _ in range(5):
            thread = threading.Thread(target=run_handler)
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join(timeout=5)

        # Verify all handlers completed successfully
        assert len(results) == 5
        assert len(errors) == 0

    def test_performance_stats_tracking(self):
        """Test that performance statistics are tracked correctly."""
        # Reset stats
        stats = get_update_stats()
        initial_total = stats["total_updates"]

        with patch("events.event_handlers.update_provider_when_message_created.db") as mock_db:
            mock_query = Mock()
            mock_db.session.query.return_value = mock_query
            mock_query.filter.return_value = mock_query
            mock_query.order_by.return_value = mock_query
            mock_query.update.return_value = 1

            # Call handler
            handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Check that stats were updated
        updated_stats = get_update_stats()
        assert updated_stats["total_updates"] == initial_total + 1
        assert updated_stats["successful_updates"] >= initial_total + 1

    def test_non_chat_entity_ignored(self):
        """Test that non-chat entities are ignored by the handler."""
        # Create a non-chat entity
        mock_non_chat_entity = Mock()
        mock_non_chat_entity.__class__.__name__ = "NonChatEntity"

        with patch("events.event_handlers.update_provider_when_message_created.db") as mock_db:
            # Call handler with non-chat entity
            handle(self.mock_message, application_generate_entity=mock_non_chat_entity)

            # Verify no database operations were performed
            assert not mock_db.session.query.called
            assert not mock_db.session.commit.called

    @patch("events.event_handlers.update_provider_when_message_created.db")
    def test_quota_calculation_tokens(self, mock_db):
        """Test quota calculation for token-based quotas."""
        # Setup token-based quota
        self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
        self.mock_message.answer_tokens = 150

        mock_query = Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.update.return_value = 1

        # Call handler
        handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Verify update was called with token count
        update_calls = mock_query.update.call_args_list
        # Should have at least one call with quota_used update
        quota_update_found = False
        for call in update_calls:
            values = call[0][0]  # First argument to update()
            if "quota_used" in values:
                quota_update_found = True
                break
        assert quota_update_found

    @patch("events.event_handlers.update_provider_when_message_created.db")
    def test_quota_calculation_times(self, mock_db):
        """Test quota calculation for times-based quotas."""
        # Setup times-based quota
        self.mock_system_config.current_quota_type = QuotaUnit.TIMES

        mock_query = Mock()
        mock_db.session.query.return_value = mock_query
        mock_query.filter.return_value = mock_query
        mock_query.order_by.return_value = mock_query
        mock_query.update.return_value = 1

        # Call handler
        handle(self.mock_message, application_generate_entity=self.mock_generate_entity)

        # Verify update was called
        assert mock_query.update.called
        assert mock_db.session.commit.called
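
Aside (not part of the diff): the retry behaviour these tests assert — a commit retried after a deadlock OperationalError, with delays of roughly 0.1s then 0.2s plus jitter — corresponds to an exponential-backoff helper along the lines below. The retry_with_backoff function is hypothetical and only illustrates the policy checked in test_deadlock_retry_mechanism and test_exponential_backoff_timing; it is not part of the handler shown above.

import random
import time


def retry_with_backoff(operation, max_retries: int = 3, base_delay: float = 0.1):
    """Run `operation`, retrying on failure with exponential backoff plus jitter."""
    for attempt in range(max_retries):
        try:
            return operation()
        except Exception:
            if attempt == max_retries - 1:
                raise
            # attempt 0 -> ~0.1-0.2s, attempt 1 -> ~0.2-0.3s, within the test bounds
            delay = base_delay * (2**attempt) + random.uniform(0, 0.1)
            time.sleep(delay)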