您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

update_provider_when_message_created.py 8.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import logging
  2. import time as time_module
  3. from datetime import datetime
  4. from typing import Any, Optional
  5. from pydantic import BaseModel
  6. from sqlalchemy import update
  7. from configs import dify_config
  8. from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
  9. from core.entities.provider_entities import QuotaUnit, SystemConfiguration
  10. from core.plugin.entities.plugin import ModelProviderID
  11. from events.message_event import message_was_created
  12. from extensions.ext_database import db
  13. from libs import datetime_utils
  14. from models.model import Message
  15. from models.provider import Provider, ProviderType
  16. logger = logging.getLogger(__name__)
  17. class _ProviderUpdateFilters(BaseModel):
  18. """Filters for identifying Provider records to update."""
  19. tenant_id: str
  20. provider_name: str
  21. provider_type: Optional[str] = None
  22. quota_type: Optional[str] = None
  23. class _ProviderUpdateAdditionalFilters(BaseModel):
  24. """Additional filters for Provider updates."""
  25. quota_limit_check: bool = False
  26. class _ProviderUpdateValues(BaseModel):
  27. """Values to update in Provider records."""
  28. last_used: Optional[datetime] = None
  29. quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
  30. class _ProviderUpdateOperation(BaseModel):
  31. """A single Provider update operation."""
  32. filters: _ProviderUpdateFilters
  33. values: _ProviderUpdateValues
  34. additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
  35. description: str = "unknown"
  36. @message_was_created.connect
  37. def handle(sender: Message, **kwargs):
  38. """
  39. Consolidated handler for Provider updates when a message is created.
  40. This handler replaces both:
  41. - update_provider_last_used_at_when_message_created
  42. - deduct_quota_when_message_created
  43. By performing all Provider updates in a single transaction, we ensure
  44. consistency and efficiency when updating Provider records.
  45. """
  46. message = sender
  47. application_generate_entity = kwargs.get("application_generate_entity")
  48. if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
  49. return
  50. tenant_id = application_generate_entity.app_config.tenant_id
  51. provider_name = application_generate_entity.model_conf.provider
  52. current_time = datetime_utils.naive_utc_now()
  53. # Prepare updates for both scenarios
  54. updates_to_perform: list[_ProviderUpdateOperation] = []
  55. # 1. Always update last_used for the provider
  56. basic_update = _ProviderUpdateOperation(
  57. filters=_ProviderUpdateFilters(
  58. tenant_id=tenant_id,
  59. provider_name=provider_name,
  60. ),
  61. values=_ProviderUpdateValues(last_used=current_time),
  62. description="basic_last_used_update",
  63. )
  64. updates_to_perform.append(basic_update)
  65. # 2. Check if we need to deduct quota (system provider only)
  66. model_config = application_generate_entity.model_conf
  67. provider_model_bundle = model_config.provider_model_bundle
  68. provider_configuration = provider_model_bundle.configuration
  69. if (
  70. provider_configuration.using_provider_type == ProviderType.SYSTEM
  71. and provider_configuration.system_configuration
  72. and provider_configuration.system_configuration.current_quota_type is not None
  73. ):
  74. system_configuration = provider_configuration.system_configuration
  75. # Calculate quota usage
  76. used_quota = _calculate_quota_usage(
  77. message=message,
  78. system_configuration=system_configuration,
  79. model_name=model_config.model,
  80. )
  81. if used_quota is not None:
  82. quota_update = _ProviderUpdateOperation(
  83. filters=_ProviderUpdateFilters(
  84. tenant_id=tenant_id,
  85. provider_name=ModelProviderID(model_config.provider).provider_name,
  86. provider_type=ProviderType.SYSTEM.value,
  87. quota_type=provider_configuration.system_configuration.current_quota_type.value,
  88. ),
  89. values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
  90. additional_filters=_ProviderUpdateAdditionalFilters(
  91. quota_limit_check=True # Provider.quota_limit > Provider.quota_used
  92. ),
  93. description="quota_deduction_update",
  94. )
  95. updates_to_perform.append(quota_update)
  96. # Execute all updates
  97. start_time = time_module.perf_counter()
  98. try:
  99. _execute_provider_updates(updates_to_perform)
  100. # Log successful completion with timing
  101. duration = time_module.perf_counter() - start_time
  102. logger.info(
  103. f"Provider updates completed successfully. "
  104. f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
  105. f"Tenant: {tenant_id}, Provider: {provider_name}"
  106. )
  107. except Exception as e:
  108. # Log failure with timing and context
  109. duration = time_module.perf_counter() - start_time
  110. logger.exception(
  111. f"Provider updates failed after {duration:.3f}s. "
  112. f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
  113. f"Provider: {provider_name}"
  114. )
  115. raise
  116. def _calculate_quota_usage(
  117. *, message: Message, system_configuration: SystemConfiguration, model_name: str
  118. ) -> Optional[int]:
  119. """Calculate quota usage based on message tokens and quota type."""
  120. quota_unit = None
  121. for quota_configuration in system_configuration.quota_configurations:
  122. if quota_configuration.quota_type == system_configuration.current_quota_type:
  123. quota_unit = quota_configuration.quota_unit
  124. if quota_configuration.quota_limit == -1:
  125. return None
  126. break
  127. if quota_unit is None:
  128. return None
  129. try:
  130. if quota_unit == QuotaUnit.TOKENS:
  131. tokens = message.message_tokens + message.answer_tokens
  132. return tokens
  133. if quota_unit == QuotaUnit.CREDITS:
  134. tokens = dify_config.get_model_credits(model_name)
  135. return tokens
  136. elif quota_unit == QuotaUnit.TIMES:
  137. return 1
  138. return None
  139. except Exception as e:
  140. logger.exception("Failed to calculate quota usage")
  141. return None
  142. def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
  143. """Execute all Provider updates in a single transaction."""
  144. if not updates_to_perform:
  145. return
  146. # Use SQLAlchemy's context manager for transaction management
  147. # This automatically handles commit/rollback
  148. with db.session.begin():
  149. # Use a single transaction for all updates
  150. for update_operation in updates_to_perform:
  151. filters = update_operation.filters
  152. values = update_operation.values
  153. additional_filters = update_operation.additional_filters
  154. description = update_operation.description
  155. # Build the where conditions
  156. where_conditions = [
  157. Provider.tenant_id == filters.tenant_id,
  158. Provider.provider_name == filters.provider_name,
  159. ]
  160. # Add additional filters if specified
  161. if filters.provider_type is not None:
  162. where_conditions.append(Provider.provider_type == filters.provider_type)
  163. if filters.quota_type is not None:
  164. where_conditions.append(Provider.quota_type == filters.quota_type)
  165. if additional_filters.quota_limit_check:
  166. where_conditions.append(Provider.quota_limit > Provider.quota_used)
  167. # Prepare values dict for SQLAlchemy update
  168. update_values = {}
  169. if values.last_used is not None:
  170. update_values["last_used"] = values.last_used
  171. if values.quota_used is not None:
  172. update_values["quota_used"] = values.quota_used
  173. # Build and execute the update statement
  174. stmt = update(Provider).where(*where_conditions).values(**update_values)
  175. result = db.session.execute(stmt)
  176. rows_affected = result.rowcount
  177. logger.debug(
  178. f"Provider update ({description}): {rows_affected} rows affected. "
  179. f"Filters: {filters.model_dump()}, Values: {update_values}"
  180. )
  181. # If no rows were affected for quota updates, log a warning
  182. if rows_affected == 0 and description == "quota_deduction_update":
  183. logger.warning(
  184. f"No Provider rows updated for quota deduction. "
  185. f"This may indicate quota limit exceeded or provider not found. "
  186. f"Filters: {filters.model_dump()}"
  187. )
  188. logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")