You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

update_provider_when_message_created.py 8.9KB


  1. import logging
  2. import time as time_module
  3. from datetime import datetime
  4. from typing import Any, Optional
  5. from pydantic import BaseModel
  6. from sqlalchemy import update
  7. from sqlalchemy.orm import Session
  8. from configs import dify_config
  9. from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
  10. from core.entities.provider_entities import QuotaUnit, SystemConfiguration
  11. from core.plugin.entities.plugin import ModelProviderID
  12. from events.message_event import message_was_created
  13. from extensions.ext_database import db
  14. from libs import datetime_utils
  15. from models.model import Message
  16. from models.provider import Provider, ProviderType
  17. logger = logging.getLogger(__name__)
  18. class _ProviderUpdateFilters(BaseModel):
  19. """Filters for identifying Provider records to update."""
  20. tenant_id: str
  21. provider_name: str
  22. provider_type: Optional[str] = None
  23. quota_type: Optional[str] = None
  24. class _ProviderUpdateAdditionalFilters(BaseModel):
  25. """Additional filters for Provider updates."""
  26. quota_limit_check: bool = False
  27. class _ProviderUpdateValues(BaseModel):
  28. """Values to update in Provider records."""
  29. last_used: Optional[datetime] = None
  30. quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
  31. class _ProviderUpdateOperation(BaseModel):
  32. """A single Provider update operation."""
  33. filters: _ProviderUpdateFilters
  34. values: _ProviderUpdateValues
  35. additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
  36. description: str = "unknown"
  37. @message_was_created.connect
  38. def handle(sender: Message, **kwargs):
  39. """
  40. Consolidated handler for Provider updates when a message is created.
  41. This handler replaces both:
  42. - update_provider_last_used_at_when_message_created
  43. - deduct_quota_when_message_created
  44. By performing all Provider updates in a single transaction, we ensure
  45. consistency and efficiency when updating Provider records.
  46. """
  47. message = sender
  48. application_generate_entity = kwargs.get("application_generate_entity")
  49. if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
  50. return
  51. tenant_id = application_generate_entity.app_config.tenant_id
  52. provider_name = application_generate_entity.model_conf.provider
  53. current_time = datetime_utils.naive_utc_now()
  54. # Prepare updates for both scenarios
  55. updates_to_perform: list[_ProviderUpdateOperation] = []
  56. # 1. Always update last_used for the provider
  57. basic_update = _ProviderUpdateOperation(
  58. filters=_ProviderUpdateFilters(
  59. tenant_id=tenant_id,
  60. provider_name=provider_name,
  61. ),
  62. values=_ProviderUpdateValues(last_used=current_time),
  63. description="basic_last_used_update",
  64. )
  65. updates_to_perform.append(basic_update)
  66. # 2. Check if we need to deduct quota (system provider only)
  67. model_config = application_generate_entity.model_conf
  68. provider_model_bundle = model_config.provider_model_bundle
  69. provider_configuration = provider_model_bundle.configuration
  70. if (
  71. provider_configuration.using_provider_type == ProviderType.SYSTEM
  72. and provider_configuration.system_configuration
  73. and provider_configuration.system_configuration.current_quota_type is not None
  74. ):
  75. system_configuration = provider_configuration.system_configuration
  76. # Calculate quota usage
  77. used_quota = _calculate_quota_usage(
  78. message=message,
  79. system_configuration=system_configuration,
  80. model_name=model_config.model,
  81. )
  82. if used_quota is not None:
  83. quota_update = _ProviderUpdateOperation(
  84. filters=_ProviderUpdateFilters(
  85. tenant_id=tenant_id,
  86. provider_name=ModelProviderID(model_config.provider).provider_name,
  87. provider_type=ProviderType.SYSTEM.value,
  88. quota_type=provider_configuration.system_configuration.current_quota_type.value,
  89. ),
  90. values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
  91. additional_filters=_ProviderUpdateAdditionalFilters(
  92. quota_limit_check=True # Provider.quota_limit > Provider.quota_used
  93. ),
  94. description="quota_deduction_update",
  95. )
  96. updates_to_perform.append(quota_update)
  97. # Execute all updates
  98. start_time = time_module.perf_counter()
  99. try:
  100. _execute_provider_updates(updates_to_perform)
  101. # Log successful completion with timing
  102. duration = time_module.perf_counter() - start_time
  103. logger.info(
  104. "Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s",
  105. len(updates_to_perform),
  106. duration,
  107. tenant_id,
  108. provider_name,
  109. )
  110. except Exception as e:
  111. # Log failure with timing and context
  112. duration = time_module.perf_counter() - start_time
  113. logger.exception(
  114. "Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s",
  115. duration,
  116. len(updates_to_perform),
  117. tenant_id,
  118. provider_name,
  119. )
  120. raise
  121. def _calculate_quota_usage(
  122. *, message: Message, system_configuration: SystemConfiguration, model_name: str
  123. ) -> Optional[int]:
  124. """Calculate quota usage based on message tokens and quota type."""
  125. quota_unit = None
  126. for quota_configuration in system_configuration.quota_configurations:
  127. if quota_configuration.quota_type == system_configuration.current_quota_type:
  128. quota_unit = quota_configuration.quota_unit
  129. if quota_configuration.quota_limit == -1:
  130. return None
  131. break
  132. if quota_unit is None:
  133. return None
  134. try:
  135. if quota_unit == QuotaUnit.TOKENS:
  136. tokens = message.message_tokens + message.answer_tokens
  137. return tokens
  138. if quota_unit == QuotaUnit.CREDITS:
  139. tokens = dify_config.get_model_credits(model_name)
  140. return tokens
  141. elif quota_unit == QuotaUnit.TIMES:
  142. return 1
  143. return None
  144. except Exception as e:
  145. logger.exception("Failed to calculate quota usage")
  146. return None
  147. def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
  148. """Execute all Provider updates in a single transaction."""
  149. if not updates_to_perform:
  150. return
  151. # Use SQLAlchemy's context manager for transaction management
  152. # This automatically handles commit/rollback
  153. with Session(db.engine) as session:
  154. # Use a single transaction for all updates
  155. for update_operation in updates_to_perform:
  156. filters = update_operation.filters
  157. values = update_operation.values
  158. additional_filters = update_operation.additional_filters
  159. description = update_operation.description
  160. # Build the where conditions
  161. where_conditions = [
  162. Provider.tenant_id == filters.tenant_id,
  163. Provider.provider_name == filters.provider_name,
  164. ]
  165. # Add additional filters if specified
  166. if filters.provider_type is not None:
  167. where_conditions.append(Provider.provider_type == filters.provider_type)
  168. if filters.quota_type is not None:
  169. where_conditions.append(Provider.quota_type == filters.quota_type)
  170. if additional_filters.quota_limit_check:
  171. where_conditions.append(Provider.quota_limit > Provider.quota_used)
  172. # Prepare values dict for SQLAlchemy update
  173. update_values = {}
  174. if values.last_used is not None:
  175. update_values["last_used"] = values.last_used
  176. if values.quota_used is not None:
  177. update_values["quota_used"] = values.quota_used
  178. # Build and execute the update statement
  179. stmt = update(Provider).where(*where_conditions).values(**update_values)
  180. result = session.execute(stmt)
  181. rows_affected = result.rowcount
  182. logger.debug(
  183. "Provider update (%s): %s rows affected. Filters: %s, Values: %s",
  184. description,
  185. rows_affected,
  186. filters.model_dump(),
  187. update_values,
  188. )
  189. # If no rows were affected for quota updates, log a warning
  190. if rows_affected == 0 and description == "quota_deduction_update":
  191. logger.warning(
  192. "No Provider rows updated for quota deduction. "
  193. "This may indicate quota limit exceeded or provider not found. "
  194. "Filters: %s",
  195. filters.model_dump(),
  196. )
  197. logger.debug("Successfully processed %s Provider updates", len(updates_to_perform))