|
|
|
@@ -1,248 +0,0 @@ |
|
|
|
import threading |
|
|
|
from unittest.mock import Mock, patch |
|
|
|
|
|
|
|
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity |
|
|
|
from core.entities.provider_entities import QuotaUnit |
|
|
|
from events.event_handlers.update_provider_when_message_created import ( |
|
|
|
handle, |
|
|
|
get_update_stats, |
|
|
|
) |
|
|
|
from models.provider import ProviderType |
|
|
|
from sqlalchemy.exc import OperationalError |
|
|
|
|
|
|
|
|
|
|
|
class TestProviderUpdateDeadlockPrevention: |
|
|
|
"""Test suite for deadlock prevention in Provider updates.""" |
|
|
|
|
|
|
|
def setup_method(self): |
|
|
|
"""Setup test fixtures.""" |
|
|
|
self.mock_message = Mock() |
|
|
|
self.mock_message.answer_tokens = 100 |
|
|
|
|
|
|
|
self.mock_app_config = Mock() |
|
|
|
self.mock_app_config.tenant_id = "test-tenant-123" |
|
|
|
|
|
|
|
self.mock_model_conf = Mock() |
|
|
|
self.mock_model_conf.provider = "openai" |
|
|
|
|
|
|
|
self.mock_system_config = Mock() |
|
|
|
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS |
|
|
|
|
|
|
|
self.mock_provider_config = Mock() |
|
|
|
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM |
|
|
|
self.mock_provider_config.system_configuration = self.mock_system_config |
|
|
|
|
|
|
|
self.mock_provider_bundle = Mock() |
|
|
|
self.mock_provider_bundle.configuration = self.mock_provider_config |
|
|
|
|
|
|
|
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle |
|
|
|
|
|
|
|
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity) |
|
|
|
self.mock_generate_entity.app_config = self.mock_app_config |
|
|
|
self.mock_generate_entity.model_conf = self.mock_model_conf |
|
|
|
|
|
|
|
@patch("events.event_handlers.update_provider_when_message_created.db") |
|
|
|
def test_consolidated_handler_basic_functionality(self, mock_db): |
|
|
|
"""Test that the consolidated handler performs both updates correctly.""" |
|
|
|
# Setup mock query chain |
|
|
|
mock_query = Mock() |
|
|
|
mock_db.session.query.return_value = mock_query |
|
|
|
mock_query.filter.return_value = mock_query |
|
|
|
mock_query.order_by.return_value = mock_query |
|
|
|
mock_query.update.return_value = 1 # 1 row affected |
|
|
|
|
|
|
|
# Call the handler |
|
|
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity) |
|
|
|
|
|
|
|
# Verify db.session.query was called |
|
|
|
assert mock_db.session.query.called |
|
|
|
|
|
|
|
# Verify commit was called |
|
|
|
mock_db.session.commit.assert_called_once() |
|
|
|
|
|
|
|
# Verify no rollback was called |
|
|
|
assert not mock_db.session.rollback.called |
|
|
|
|
|
|
|
@patch("events.event_handlers.update_provider_when_message_created.db") |
|
|
|
def test_deadlock_retry_mechanism(self, mock_db): |
|
|
|
"""Test that deadlock errors trigger retry logic.""" |
|
|
|
# Setup mock to raise deadlock error on first attempt, succeed on second |
|
|
|
mock_query = Mock() |
|
|
|
mock_db.session.query.return_value = mock_query |
|
|
|
mock_query.filter.return_value = mock_query |
|
|
|
mock_query.order_by.return_value = mock_query |
|
|
|
mock_query.update.return_value = 1 |
|
|
|
|
|
|
|
# First call raises deadlock, second succeeds |
|
|
|
mock_db.session.commit.side_effect = [ |
|
|
|
OperationalError("deadlock detected", None, None), |
|
|
|
None, # Success on retry |
|
|
|
] |
|
|
|
|
|
|
|
# Call the handler |
|
|
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity) |
|
|
|
|
|
|
|
# Verify commit was called twice (original + retry) |
|
|
|
assert mock_db.session.commit.call_count == 2 |
|
|
|
|
|
|
|
# Verify rollback was called once (after first failure) |
|
|
|
mock_db.session.rollback.assert_called_once() |
|
|
|
|
|
|
|
@patch("events.event_handlers.update_provider_when_message_created.db") |
|
|
|
@patch("events.event_handlers.update_provider_when_message_created.time.sleep") |
|
|
|
def test_exponential_backoff_timing(self, mock_sleep, mock_db): |
|
|
|
"""Test that retry delays follow exponential backoff pattern.""" |
|
|
|
# Setup mock to fail twice, succeed on third attempt |
|
|
|
mock_query = Mock() |
|
|
|
mock_db.session.query.return_value = mock_query |
|
|
|
mock_query.filter.return_value = mock_query |
|
|
|
mock_query.order_by.return_value = mock_query |
|
|
|
mock_query.update.return_value = 1 |
|
|
|
|
|
|
|
mock_db.session.commit.side_effect = [ |
|
|
|
OperationalError("deadlock detected", None, None), |
|
|
|
OperationalError("deadlock detected", None, None), |
|
|
|
None, # Success on third attempt |
|
|
|
] |
|
|
|
|
|
|
|
# Call the handler |
|
|
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity) |
|
|
|
|
|
|
|
# Verify sleep was called twice with increasing delays |
|
|
|
assert mock_sleep.call_count == 2 |
|
|
|
|
|
|
|
# First delay should be around 0.1s + jitter |
|
|
|
first_delay = mock_sleep.call_args_list[0][0][0] |
|
|
|
assert 0.1 <= first_delay <= 0.3 |
|
|
|
|
|
|
|
# Second delay should be around 0.2s + jitter |
|
|
|
second_delay = mock_sleep.call_args_list[1][0][0] |
|
|
|
assert 0.2 <= second_delay <= 0.4 |
|
|
|
|
|
|
|
def test_concurrent_handler_execution(self): |
|
|
|
"""Test that multiple handlers can run concurrently without deadlock.""" |
|
|
|
results = [] |
|
|
|
errors = [] |
|
|
|
|
|
|
|
def run_handler(): |
|
|
|
try: |
|
|
|
with patch( |
|
|
|
"events.event_handlers.update_provider_when_message_created.db" |
|
|
|
) as mock_db: |
|
|
|
mock_query = Mock() |
|
|
|
mock_db.session.query.return_value = mock_query |
|
|
|
mock_query.filter.return_value = mock_query |
|
|
|
mock_query.order_by.return_value = mock_query |
|
|
|
mock_query.update.return_value = 1 |
|
|
|
|
|
|
|
handle( |
|
|
|
self.mock_message, |
|
|
|
application_generate_entity=self.mock_generate_entity, |
|
|
|
) |
|
|
|
results.append("success") |
|
|
|
except Exception as e: |
|
|
|
errors.append(str(e)) |
|
|
|
|
|
|
|
# Run multiple handlers concurrently |
|
|
|
threads = [] |
|
|
|
for _ in range(5): |
|
|
|
thread = threading.Thread(target=run_handler) |
|
|
|
threads.append(thread) |
|
|
|
thread.start() |
|
|
|
|
|
|
|
# Wait for all threads to complete |
|
|
|
for thread in threads: |
|
|
|
thread.join(timeout=5) |
|
|
|
|
|
|
|
# Verify all handlers completed successfully |
|
|
|
assert len(results) == 5 |
|
|
|
assert len(errors) == 0 |
|
|
|
|
|
|
|
def test_performance_stats_tracking(self): |
|
|
|
"""Test that performance statistics are tracked correctly.""" |
|
|
|
# Reset stats |
|
|
|
stats = get_update_stats() |
|
|
|
initial_total = stats["total_updates"] |
|
|
|
|
|
|
|
with patch( |
|
|
|
"events.event_handlers.update_provider_when_message_created.db" |
|
|
|
) as mock_db: |
|
|
|
mock_query = Mock() |
|
|
|
mock_db.session.query.return_value = mock_query |
|
|
|
mock_query.filter.return_value = mock_query |
|
|
|
mock_query.order_by.return_value = mock_query |
|
|
|
mock_query.update.return_value = 1 |
|
|
|
|
|
|
|
# Call handler |
|
|
|
handle( |
|
|
|
self.mock_message, application_generate_entity=self.mock_generate_entity |
|
|
|
) |
|
|
|
|
|
|
|
# Check that stats were updated |
|
|
|
updated_stats = get_update_stats() |
|
|
|
assert updated_stats["total_updates"] == initial_total + 1 |
|
|
|
assert updated_stats["successful_updates"] >= initial_total + 1 |
|
|
|
|
|
|
|
def test_non_chat_entity_ignored(self): |
|
|
|
"""Test that non-chat entities are ignored by the handler.""" |
|
|
|
# Create a non-chat entity |
|
|
|
mock_non_chat_entity = Mock() |
|
|
|
mock_non_chat_entity.__class__.__name__ = "NonChatEntity" |
|
|
|
|
|
|
|
with patch( |
|
|
|
"events.event_handlers.update_provider_when_message_created.db" |
|
|
|
) as mock_db: |
|
|
|
# Call handler with non-chat entity |
|
|
|
handle(self.mock_message, application_generate_entity=mock_non_chat_entity) |
|
|
|
|
|
|
|
# Verify no database operations were performed |
|
|
|
assert not mock_db.session.query.called |
|
|
|
assert not mock_db.session.commit.called |
|
|
|
|
|
|
|
@patch("events.event_handlers.update_provider_when_message_created.db") |
|
|
|
def test_quota_calculation_tokens(self, mock_db): |
|
|
|
"""Test quota calculation for token-based quotas.""" |
|
|
|
# Setup token-based quota |
|
|
|
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS |
|
|
|
self.mock_message.answer_tokens = 150 |
|
|
|
|
|
|
|
mock_query = Mock() |
|
|
|
mock_db.session.query.return_value = mock_query |
|
|
|
mock_query.filter.return_value = mock_query |
|
|
|
mock_query.order_by.return_value = mock_query |
|
|
|
mock_query.update.return_value = 1 |
|
|
|
|
|
|
|
# Call handler |
|
|
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity) |
|
|
|
|
|
|
|
# Verify update was called with token count |
|
|
|
update_calls = mock_query.update.call_args_list |
|
|
|
|
|
|
|
# Should have at least one call with quota_used update |
|
|
|
quota_update_found = False |
|
|
|
for call in update_calls: |
|
|
|
values = call[0][0] # First argument to update() |
|
|
|
if "quota_used" in values: |
|
|
|
quota_update_found = True |
|
|
|
break |
|
|
|
|
|
|
|
assert quota_update_found |
|
|
|
|
|
|
|
@patch("events.event_handlers.update_provider_when_message_created.db") |
|
|
|
def test_quota_calculation_times(self, mock_db): |
|
|
|
"""Test quota calculation for times-based quotas.""" |
|
|
|
# Setup times-based quota |
|
|
|
self.mock_system_config.current_quota_type = QuotaUnit.TIMES |
|
|
|
|
|
|
|
mock_query = Mock() |
|
|
|
mock_db.session.query.return_value = mock_query |
|
|
|
mock_query.filter.return_value = mock_query |
|
|
|
mock_query.order_by.return_value = mock_query |
|
|
|
mock_query.update.return_value = 1 |
|
|
|
|
|
|
|
# Call handler |
|
|
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity) |
|
|
|
|
|
|
|
# Verify update was called |
|
|
|
assert mock_query.update.called |
|
|
|
assert mock_db.session.commit.called |