您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

update_provider_when_message_created.py 9.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import logging
  2. import time as time_module
  3. from datetime import datetime
  4. from typing import Any, Optional
  5. from pydantic import BaseModel
  6. from sqlalchemy import update
  7. from sqlalchemy.orm import Session
  8. from configs import dify_config
  9. from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
  10. from core.entities.provider_entities import QuotaUnit, SystemConfiguration
  11. from core.plugin.entities.plugin import ModelProviderID
  12. from events.message_event import message_was_created
  13. from extensions.ext_database import db
  14. from libs import datetime_utils
  15. from models.model import Message
  16. from models.provider import Provider, ProviderType
  17. logger = logging.getLogger(__name__)
  18. class _ProviderUpdateFilters(BaseModel):
  19. """Filters for identifying Provider records to update."""
  20. tenant_id: str
  21. provider_name: str
  22. provider_type: Optional[str] = None
  23. quota_type: Optional[str] = None
  24. class _ProviderUpdateAdditionalFilters(BaseModel):
  25. """Additional filters for Provider updates."""
  26. quota_limit_check: bool = False
  27. class _ProviderUpdateValues(BaseModel):
  28. """Values to update in Provider records."""
  29. last_used: Optional[datetime] = None
  30. quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
  31. class _ProviderUpdateOperation(BaseModel):
  32. """A single Provider update operation."""
  33. filters: _ProviderUpdateFilters
  34. values: _ProviderUpdateValues
  35. additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
  36. description: str = "unknown"
  37. @message_was_created.connect
  38. def handle(sender: Message, **kwargs):
  39. """
  40. Consolidated handler for Provider updates when a message is created.
  41. This handler replaces both:
  42. - update_provider_last_used_at_when_message_created
  43. - deduct_quota_when_message_created
  44. By performing all Provider updates in a single transaction, we ensure
  45. consistency and efficiency when updating Provider records.
  46. """
  47. message = sender
  48. application_generate_entity = kwargs.get("application_generate_entity")
  49. if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
  50. return
  51. tenant_id = application_generate_entity.app_config.tenant_id
  52. provider_name = application_generate_entity.model_conf.provider
  53. current_time = datetime_utils.naive_utc_now()
  54. # Prepare updates for both scenarios
  55. updates_to_perform: list[_ProviderUpdateOperation] = []
  56. # 1. Always update last_used for the provider
  57. basic_update = _ProviderUpdateOperation(
  58. filters=_ProviderUpdateFilters(
  59. tenant_id=tenant_id,
  60. provider_name=provider_name,
  61. ),
  62. values=_ProviderUpdateValues(last_used=current_time),
  63. description="basic_last_used_update",
  64. )
  65. logger.info("provider used, tenant_id=%s, provider_name=%s", tenant_id, provider_name)
  66. updates_to_perform.append(basic_update)
  67. # 2. Check if we need to deduct quota (system provider only)
  68. model_config = application_generate_entity.model_conf
  69. provider_model_bundle = model_config.provider_model_bundle
  70. provider_configuration = provider_model_bundle.configuration
  71. if (
  72. provider_configuration.using_provider_type == ProviderType.SYSTEM
  73. and provider_configuration.system_configuration
  74. and provider_configuration.system_configuration.current_quota_type is not None
  75. ):
  76. system_configuration = provider_configuration.system_configuration
  77. # Calculate quota usage
  78. used_quota = _calculate_quota_usage(
  79. message=message,
  80. system_configuration=system_configuration,
  81. model_name=model_config.model,
  82. )
  83. if used_quota is not None:
  84. quota_update = _ProviderUpdateOperation(
  85. filters=_ProviderUpdateFilters(
  86. tenant_id=tenant_id,
  87. provider_name=ModelProviderID(model_config.provider).provider_name,
  88. provider_type=ProviderType.SYSTEM.value,
  89. quota_type=provider_configuration.system_configuration.current_quota_type.value,
  90. ),
  91. values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
  92. additional_filters=_ProviderUpdateAdditionalFilters(
  93. quota_limit_check=True # Provider.quota_limit > Provider.quota_used
  94. ),
  95. description="quota_deduction_update",
  96. )
  97. updates_to_perform.append(quota_update)
  98. # Execute all updates
  99. start_time = time_module.perf_counter()
  100. try:
  101. _execute_provider_updates(updates_to_perform)
  102. # Log successful completion with timing
  103. duration = time_module.perf_counter() - start_time
  104. logger.info(
  105. "Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s",
  106. len(updates_to_perform),
  107. duration,
  108. tenant_id,
  109. provider_name,
  110. )
  111. except Exception as e:
  112. # Log failure with timing and context
  113. duration = time_module.perf_counter() - start_time
  114. logger.exception(
  115. "Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s",
  116. duration,
  117. len(updates_to_perform),
  118. tenant_id,
  119. provider_name,
  120. )
  121. raise
  122. def _calculate_quota_usage(
  123. *, message: Message, system_configuration: SystemConfiguration, model_name: str
  124. ) -> Optional[int]:
  125. """Calculate quota usage based on message tokens and quota type."""
  126. quota_unit = None
  127. for quota_configuration in system_configuration.quota_configurations:
  128. if quota_configuration.quota_type == system_configuration.current_quota_type:
  129. quota_unit = quota_configuration.quota_unit
  130. if quota_configuration.quota_limit == -1:
  131. return None
  132. break
  133. if quota_unit is None:
  134. return None
  135. try:
  136. if quota_unit == QuotaUnit.TOKENS:
  137. tokens = message.message_tokens + message.answer_tokens
  138. return tokens
  139. if quota_unit == QuotaUnit.CREDITS:
  140. tokens = dify_config.get_model_credits(model_name)
  141. return tokens
  142. elif quota_unit == QuotaUnit.TIMES:
  143. return 1
  144. return None
  145. except Exception as e:
  146. logger.exception("Failed to calculate quota usage")
  147. return None
  148. def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
  149. """Execute all Provider updates in a single transaction."""
  150. if not updates_to_perform:
  151. return
  152. updates_to_perform = sorted(updates_to_perform, key=lambda i: (i.filters.tenant_id, i.filters.provider_name))
  153. # Use SQLAlchemy's context manager for transaction management
  154. # This automatically handles commit/rollback
  155. with Session(db.engine) as session, session.begin():
  156. # Use a single transaction for all updates
  157. for update_operation in updates_to_perform:
  158. filters = update_operation.filters
  159. values = update_operation.values
  160. additional_filters = update_operation.additional_filters
  161. description = update_operation.description
  162. # Build the where conditions
  163. where_conditions = [
  164. Provider.tenant_id == filters.tenant_id,
  165. Provider.provider_name == filters.provider_name,
  166. ]
  167. # Add additional filters if specified
  168. if filters.provider_type is not None:
  169. where_conditions.append(Provider.provider_type == filters.provider_type)
  170. if filters.quota_type is not None:
  171. where_conditions.append(Provider.quota_type == filters.quota_type)
  172. if additional_filters.quota_limit_check:
  173. where_conditions.append(Provider.quota_limit > Provider.quota_used)
  174. # Prepare values dict for SQLAlchemy update
  175. update_values = {}
  176. # updateing to `last_used` is removed due to performance reason.
  177. # ref: https://github.com/langgenius/dify/issues/24526
  178. if values.quota_used is not None:
  179. update_values["quota_used"] = values.quota_used
  180. # Skip the current update operation if no updates are required.
  181. if not update_values:
  182. continue
  183. # Build and execute the update statement
  184. stmt = update(Provider).where(*where_conditions).values(**update_values)
  185. result = session.execute(stmt)
  186. rows_affected = result.rowcount
  187. logger.debug(
  188. "Provider update (%s): %s rows affected. Filters: %s, Values: %s",
  189. description,
  190. rows_affected,
  191. filters.model_dump(),
  192. update_values,
  193. )
  194. # If no rows were affected for quota updates, log a warning
  195. if rows_affected == 0 and description == "quota_deduction_update":
  196. logger.warning(
  197. "No Provider rows updated for quota deduction. "
  198. "This may indicate quota limit exceeded or provider not found. "
  199. "Filters: %s",
  200. filters.model_dump(),
  201. )
  202. logger.debug("Successfully processed %s Provider updates", len(updates_to_perform))