You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_provider_update_deadlock_prevention.py 9.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import threading
  2. from unittest.mock import Mock, patch
  3. from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
  4. from core.entities.provider_entities import QuotaUnit
  5. from events.event_handlers.update_provider_when_message_created import (
  6. handle,
  7. get_update_stats,
  8. )
  9. from models.provider import ProviderType
  10. from sqlalchemy.exc import OperationalError
  11. class TestProviderUpdateDeadlockPrevention:
  12. """Test suite for deadlock prevention in Provider updates."""
  13. def setup_method(self):
  14. """Setup test fixtures."""
  15. self.mock_message = Mock()
  16. self.mock_message.answer_tokens = 100
  17. self.mock_app_config = Mock()
  18. self.mock_app_config.tenant_id = "test-tenant-123"
  19. self.mock_model_conf = Mock()
  20. self.mock_model_conf.provider = "openai"
  21. self.mock_system_config = Mock()
  22. self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
  23. self.mock_provider_config = Mock()
  24. self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
  25. self.mock_provider_config.system_configuration = self.mock_system_config
  26. self.mock_provider_bundle = Mock()
  27. self.mock_provider_bundle.configuration = self.mock_provider_config
  28. self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
  29. self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
  30. self.mock_generate_entity.app_config = self.mock_app_config
  31. self.mock_generate_entity.model_conf = self.mock_model_conf
  32. @patch("events.event_handlers.update_provider_when_message_created.db")
  33. def test_consolidated_handler_basic_functionality(self, mock_db):
  34. """Test that the consolidated handler performs both updates correctly."""
  35. # Setup mock query chain
  36. mock_query = Mock()
  37. mock_db.session.query.return_value = mock_query
  38. mock_query.filter.return_value = mock_query
  39. mock_query.order_by.return_value = mock_query
  40. mock_query.update.return_value = 1 # 1 row affected
  41. # Call the handler
  42. handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
  43. # Verify db.session.query was called
  44. assert mock_db.session.query.called
  45. # Verify commit was called
  46. mock_db.session.commit.assert_called_once()
  47. # Verify no rollback was called
  48. assert not mock_db.session.rollback.called
  49. @patch("events.event_handlers.update_provider_when_message_created.db")
  50. def test_deadlock_retry_mechanism(self, mock_db):
  51. """Test that deadlock errors trigger retry logic."""
  52. # Setup mock to raise deadlock error on first attempt, succeed on second
  53. mock_query = Mock()
  54. mock_db.session.query.return_value = mock_query
  55. mock_query.filter.return_value = mock_query
  56. mock_query.order_by.return_value = mock_query
  57. mock_query.update.return_value = 1
  58. # First call raises deadlock, second succeeds
  59. mock_db.session.commit.side_effect = [
  60. OperationalError("deadlock detected", None, None),
  61. None, # Success on retry
  62. ]
  63. # Call the handler
  64. handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
  65. # Verify commit was called twice (original + retry)
  66. assert mock_db.session.commit.call_count == 2
  67. # Verify rollback was called once (after first failure)
  68. mock_db.session.rollback.assert_called_once()
  69. @patch("events.event_handlers.update_provider_when_message_created.db")
  70. @patch("events.event_handlers.update_provider_when_message_created.time.sleep")
  71. def test_exponential_backoff_timing(self, mock_sleep, mock_db):
  72. """Test that retry delays follow exponential backoff pattern."""
  73. # Setup mock to fail twice, succeed on third attempt
  74. mock_query = Mock()
  75. mock_db.session.query.return_value = mock_query
  76. mock_query.filter.return_value = mock_query
  77. mock_query.order_by.return_value = mock_query
  78. mock_query.update.return_value = 1
  79. mock_db.session.commit.side_effect = [
  80. OperationalError("deadlock detected", None, None),
  81. OperationalError("deadlock detected", None, None),
  82. None, # Success on third attempt
  83. ]
  84. # Call the handler
  85. handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
  86. # Verify sleep was called twice with increasing delays
  87. assert mock_sleep.call_count == 2
  88. # First delay should be around 0.1s + jitter
  89. first_delay = mock_sleep.call_args_list[0][0][0]
  90. assert 0.1 <= first_delay <= 0.3
  91. # Second delay should be around 0.2s + jitter
  92. second_delay = mock_sleep.call_args_list[1][0][0]
  93. assert 0.2 <= second_delay <= 0.4
  94. def test_concurrent_handler_execution(self):
  95. """Test that multiple handlers can run concurrently without deadlock."""
  96. results = []
  97. errors = []
  98. def run_handler():
  99. try:
  100. with patch(
  101. "events.event_handlers.update_provider_when_message_created.db"
  102. ) as mock_db:
  103. mock_query = Mock()
  104. mock_db.session.query.return_value = mock_query
  105. mock_query.filter.return_value = mock_query
  106. mock_query.order_by.return_value = mock_query
  107. mock_query.update.return_value = 1
  108. handle(
  109. self.mock_message,
  110. application_generate_entity=self.mock_generate_entity,
  111. )
  112. results.append("success")
  113. except Exception as e:
  114. errors.append(str(e))
  115. # Run multiple handlers concurrently
  116. threads = []
  117. for _ in range(5):
  118. thread = threading.Thread(target=run_handler)
  119. threads.append(thread)
  120. thread.start()
  121. # Wait for all threads to complete
  122. for thread in threads:
  123. thread.join(timeout=5)
  124. # Verify all handlers completed successfully
  125. assert len(results) == 5
  126. assert len(errors) == 0
  127. def test_performance_stats_tracking(self):
  128. """Test that performance statistics are tracked correctly."""
  129. # Reset stats
  130. stats = get_update_stats()
  131. initial_total = stats["total_updates"]
  132. with patch(
  133. "events.event_handlers.update_provider_when_message_created.db"
  134. ) as mock_db:
  135. mock_query = Mock()
  136. mock_db.session.query.return_value = mock_query
  137. mock_query.filter.return_value = mock_query
  138. mock_query.order_by.return_value = mock_query
  139. mock_query.update.return_value = 1
  140. # Call handler
  141. handle(
  142. self.mock_message, application_generate_entity=self.mock_generate_entity
  143. )
  144. # Check that stats were updated
  145. updated_stats = get_update_stats()
  146. assert updated_stats["total_updates"] == initial_total + 1
  147. assert updated_stats["successful_updates"] >= initial_total + 1
  148. def test_non_chat_entity_ignored(self):
  149. """Test that non-chat entities are ignored by the handler."""
  150. # Create a non-chat entity
  151. mock_non_chat_entity = Mock()
  152. mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
  153. with patch(
  154. "events.event_handlers.update_provider_when_message_created.db"
  155. ) as mock_db:
  156. # Call handler with non-chat entity
  157. handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
  158. # Verify no database operations were performed
  159. assert not mock_db.session.query.called
  160. assert not mock_db.session.commit.called
  161. @patch("events.event_handlers.update_provider_when_message_created.db")
  162. def test_quota_calculation_tokens(self, mock_db):
  163. """Test quota calculation for token-based quotas."""
  164. # Setup token-based quota
  165. self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
  166. self.mock_message.answer_tokens = 150
  167. mock_query = Mock()
  168. mock_db.session.query.return_value = mock_query
  169. mock_query.filter.return_value = mock_query
  170. mock_query.order_by.return_value = mock_query
  171. mock_query.update.return_value = 1
  172. # Call handler
  173. handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
  174. # Verify update was called with token count
  175. update_calls = mock_query.update.call_args_list
  176. # Should have at least one call with quota_used update
  177. quota_update_found = False
  178. for call in update_calls:
  179. values = call[0][0] # First argument to update()
  180. if "quota_used" in values:
  181. quota_update_found = True
  182. break
  183. assert quota_update_found
  184. @patch("events.event_handlers.update_provider_when_message_created.db")
  185. def test_quota_calculation_times(self, mock_db):
  186. """Test quota calculation for times-based quotas."""
  187. # Setup times-based quota
  188. self.mock_system_config.current_quota_type = QuotaUnit.TIMES
  189. mock_query = Mock()
  190. mock_db.session.query.return_value = mock_query
  191. mock_query.filter.return_value = mock_query
  192. mock_query.order_by.return_value = mock_query
  193. mock_query.update.return_value = 1
  194. # Call handler
  195. handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
  196. # Verify update was called
  197. assert mock_query.update.called
  198. assert mock_db.session.commit.called