You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

update_provider_when_message_created.py 8.8KB


  1. import logging
  2. import time as time_module
  3. from datetime import datetime
  4. from typing import Any, Optional
  5. from pydantic import BaseModel
  6. from sqlalchemy import update
  7. from sqlalchemy.orm import Session
  8. from configs import dify_config
  9. from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
  10. from core.entities.provider_entities import QuotaUnit, SystemConfiguration
  11. from core.plugin.entities.plugin import ModelProviderID
  12. from events.message_event import message_was_created
  13. from extensions.ext_database import db
  14. from libs import datetime_utils
  15. from models.model import Message
  16. from models.provider import Provider, ProviderType
  17. logger = logging.getLogger(__name__)
  18. class _ProviderUpdateFilters(BaseModel):
  19. """Filters for identifying Provider records to update."""
  20. tenant_id: str
  21. provider_name: str
  22. provider_type: Optional[str] = None
  23. quota_type: Optional[str] = None
  24. class _ProviderUpdateAdditionalFilters(BaseModel):
  25. """Additional filters for Provider updates."""
  26. quota_limit_check: bool = False
  27. class _ProviderUpdateValues(BaseModel):
  28. """Values to update in Provider records."""
  29. last_used: Optional[datetime] = None
  30. quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
  31. class _ProviderUpdateOperation(BaseModel):
  32. """A single Provider update operation."""
  33. filters: _ProviderUpdateFilters
  34. values: _ProviderUpdateValues
  35. additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
  36. description: str = "unknown"
  37. @message_was_created.connect
  38. def handle(sender: Message, **kwargs):
  39. """
  40. Consolidated handler for Provider updates when a message is created.
  41. This handler replaces both:
  42. - update_provider_last_used_at_when_message_created
  43. - deduct_quota_when_message_created
  44. By performing all Provider updates in a single transaction, we ensure
  45. consistency and efficiency when updating Provider records.
  46. """
  47. message = sender
  48. application_generate_entity = kwargs.get("application_generate_entity")
  49. if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
  50. return
  51. tenant_id = application_generate_entity.app_config.tenant_id
  52. provider_name = application_generate_entity.model_conf.provider
  53. current_time = datetime_utils.naive_utc_now()
  54. # Prepare updates for both scenarios
  55. updates_to_perform: list[_ProviderUpdateOperation] = []
  56. # 1. Always update last_used for the provider
  57. basic_update = _ProviderUpdateOperation(
  58. filters=_ProviderUpdateFilters(
  59. tenant_id=tenant_id,
  60. provider_name=provider_name,
  61. ),
  62. values=_ProviderUpdateValues(last_used=current_time),
  63. description="basic_last_used_update",
  64. )
  65. updates_to_perform.append(basic_update)
  66. # 2. Check if we need to deduct quota (system provider only)
  67. model_config = application_generate_entity.model_conf
  68. provider_model_bundle = model_config.provider_model_bundle
  69. provider_configuration = provider_model_bundle.configuration
  70. if (
  71. provider_configuration.using_provider_type == ProviderType.SYSTEM
  72. and provider_configuration.system_configuration
  73. and provider_configuration.system_configuration.current_quota_type is not None
  74. ):
  75. system_configuration = provider_configuration.system_configuration
  76. # Calculate quota usage
  77. used_quota = _calculate_quota_usage(
  78. message=message,
  79. system_configuration=system_configuration,
  80. model_name=model_config.model,
  81. )
  82. if used_quota is not None:
  83. quota_update = _ProviderUpdateOperation(
  84. filters=_ProviderUpdateFilters(
  85. tenant_id=tenant_id,
  86. provider_name=ModelProviderID(model_config.provider).provider_name,
  87. provider_type=ProviderType.SYSTEM.value,
  88. quota_type=provider_configuration.system_configuration.current_quota_type.value,
  89. ),
  90. values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
  91. additional_filters=_ProviderUpdateAdditionalFilters(
  92. quota_limit_check=True # Provider.quota_limit > Provider.quota_used
  93. ),
  94. description="quota_deduction_update",
  95. )
  96. updates_to_perform.append(quota_update)
  97. # Execute all updates
  98. start_time = time_module.perf_counter()
  99. try:
  100. _execute_provider_updates(updates_to_perform)
  101. # Log successful completion with timing
  102. duration = time_module.perf_counter() - start_time
  103. logger.info(
  104. f"Provider updates completed successfully. "
  105. f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
  106. f"Tenant: {tenant_id}, Provider: {provider_name}"
  107. )
  108. except Exception as e:
  109. # Log failure with timing and context
  110. duration = time_module.perf_counter() - start_time
  111. logger.exception(
  112. f"Provider updates failed after {duration:.3f}s. "
  113. f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
  114. f"Provider: {provider_name}"
  115. )
  116. raise
  117. def _calculate_quota_usage(
  118. *, message: Message, system_configuration: SystemConfiguration, model_name: str
  119. ) -> Optional[int]:
  120. """Calculate quota usage based on message tokens and quota type."""
  121. quota_unit = None
  122. for quota_configuration in system_configuration.quota_configurations:
  123. if quota_configuration.quota_type == system_configuration.current_quota_type:
  124. quota_unit = quota_configuration.quota_unit
  125. if quota_configuration.quota_limit == -1:
  126. return None
  127. break
  128. if quota_unit is None:
  129. return None
  130. try:
  131. if quota_unit == QuotaUnit.TOKENS:
  132. tokens = message.message_tokens + message.answer_tokens
  133. return tokens
  134. if quota_unit == QuotaUnit.CREDITS:
  135. tokens = dify_config.get_model_credits(model_name)
  136. return tokens
  137. elif quota_unit == QuotaUnit.TIMES:
  138. return 1
  139. return None
  140. except Exception as e:
  141. logger.exception("Failed to calculate quota usage")
  142. return None
  143. def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
  144. """Execute all Provider updates in a single transaction."""
  145. if not updates_to_perform:
  146. return
  147. # Use SQLAlchemy's context manager for transaction management
  148. # This automatically handles commit/rollback
  149. with Session(db.engine) as session:
  150. # Use a single transaction for all updates
  151. for update_operation in updates_to_perform:
  152. filters = update_operation.filters
  153. values = update_operation.values
  154. additional_filters = update_operation.additional_filters
  155. description = update_operation.description
  156. # Build the where conditions
  157. where_conditions = [
  158. Provider.tenant_id == filters.tenant_id,
  159. Provider.provider_name == filters.provider_name,
  160. ]
  161. # Add additional filters if specified
  162. if filters.provider_type is not None:
  163. where_conditions.append(Provider.provider_type == filters.provider_type)
  164. if filters.quota_type is not None:
  165. where_conditions.append(Provider.quota_type == filters.quota_type)
  166. if additional_filters.quota_limit_check:
  167. where_conditions.append(Provider.quota_limit > Provider.quota_used)
  168. # Prepare values dict for SQLAlchemy update
  169. update_values = {}
  170. if values.last_used is not None:
  171. update_values["last_used"] = values.last_used
  172. if values.quota_used is not None:
  173. update_values["quota_used"] = values.quota_used
  174. # Build and execute the update statement
  175. stmt = update(Provider).where(*where_conditions).values(**update_values)
  176. result = session.execute(stmt)
  177. rows_affected = result.rowcount
  178. logger.debug(
  179. f"Provider update ({description}): {rows_affected} rows affected. "
  180. f"Filters: {filters.model_dump()}, Values: {update_values}"
  181. )
  182. # If no rows were affected for quota updates, log a warning
  183. if rows_affected == 0 and description == "quota_deduction_update":
  184. logger.warning(
  185. f"No Provider rows updated for quota deduction. "
  186. f"This may indicate quota limit exceeded or provider not found. "
  187. f"Filters: {filters.model_dump()}"
  188. )
  189. logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")