
update_provider_when_message_created.py

import logging
import time as time_module
from datetime import datetime
from typing import Any, Optional

from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType

logger = logging.getLogger(__name__)

# Redis cache key prefix for provider last used timestamps
_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used"

# Default TTL for cache entries (10 minutes)
_CACHE_TTL_SECONDS = 600

LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5


def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:
    """Generate Redis cache key for provider last used timestamp."""
    return f"{_PROVIDER_LAST_USED_CACHE_PREFIX}:{tenant_id}:{provider_name}"


@redis_fallback(default_return=None)
def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]:
    """Get last update timestamp from Redis cache."""
    timestamp_str = redis_client.get(cache_key)
    if timestamp_str:
        return datetime.fromtimestamp(float(timestamp_str.decode("utf-8")))
    return None


@redis_fallback()
def _set_last_update_timestamp(cache_key: str, timestamp: datetime) -> None:
    """Set last update timestamp in Redis cache with TTL."""
    redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp()))


class _ProviderUpdateFilters(BaseModel):
    """Filters for identifying Provider records to update."""

    tenant_id: str
    provider_name: str
    provider_type: Optional[str] = None
    quota_type: Optional[str] = None


class _ProviderUpdateAdditionalFilters(BaseModel):
    """Additional filters for Provider updates."""

    quota_limit_check: bool = False


class _ProviderUpdateValues(BaseModel):
    """Values to update in Provider records."""

    last_used: Optional[datetime] = None
    quota_used: Optional[Any] = None  # Can be Provider.quota_used + int expression


class _ProviderUpdateOperation(BaseModel):
    """A single Provider update operation."""

    filters: _ProviderUpdateFilters
    values: _ProviderUpdateValues
    additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
    description: str = "unknown"


@message_was_created.connect
def handle(sender: Message, **kwargs):
    """
    Consolidated handler for Provider updates when a message is created.

    This handler replaces both:
    - update_provider_last_used_at_when_message_created
    - deduct_quota_when_message_created

    By performing all Provider updates in a single transaction, we ensure
    consistency and efficiency when updating Provider records.
    """
    message = sender
    application_generate_entity = kwargs.get("application_generate_entity")

    if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
        return

    tenant_id = application_generate_entity.app_config.tenant_id
    provider_name = application_generate_entity.model_conf.provider
    current_time = datetime_utils.naive_utc_now()

    # Prepare updates for both scenarios
    updates_to_perform: list[_ProviderUpdateOperation] = []

    # 1. Always update last_used for the provider
    basic_update = _ProviderUpdateOperation(
        filters=_ProviderUpdateFilters(
            tenant_id=tenant_id,
            provider_name=provider_name,
        ),
        values=_ProviderUpdateValues(last_used=current_time),
        description="basic_last_used_update",
    )
    logger.info("provider used, tenant_id=%s, provider_name=%s", tenant_id, provider_name)
    updates_to_perform.append(basic_update)

    # 2. Check if we need to deduct quota (system provider only)
    model_config = application_generate_entity.model_conf
    provider_model_bundle = model_config.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration

    if (
        provider_configuration.using_provider_type == ProviderType.SYSTEM
        and provider_configuration.system_configuration
        and provider_configuration.system_configuration.current_quota_type is not None
    ):
        system_configuration = provider_configuration.system_configuration

        # Calculate quota usage
        used_quota = _calculate_quota_usage(
            message=message,
            system_configuration=system_configuration,
            model_name=model_config.model,
        )

        if used_quota is not None:
            quota_update = _ProviderUpdateOperation(
                filters=_ProviderUpdateFilters(
                    tenant_id=tenant_id,
                    provider_name=ModelProviderID(model_config.provider).provider_name,
                    provider_type=ProviderType.SYSTEM.value,
                    quota_type=provider_configuration.system_configuration.current_quota_type.value,
                ),
                values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
                additional_filters=_ProviderUpdateAdditionalFilters(
                    quota_limit_check=True  # Provider.quota_limit > Provider.quota_used
                ),
                description="quota_deduction_update",
            )
            updates_to_perform.append(quota_update)

    # Execute all updates
    start_time = time_module.perf_counter()
    try:
        _execute_provider_updates(updates_to_perform)

        # Log successful completion with timing
        duration = time_module.perf_counter() - start_time
        logger.info(
            "Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s",
            len(updates_to_perform),
            duration,
            tenant_id,
            provider_name,
        )
    except Exception:
        # Log failure with timing and context
        duration = time_module.perf_counter() - start_time
        logger.exception(
            "Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s",
            duration,
            len(updates_to_perform),
            tenant_id,
            provider_name,
        )
        raise


def _calculate_quota_usage(
    *, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> Optional[int]:
    """Calculate quota usage based on message tokens and quota type."""
    quota_unit = None
    for quota_configuration in system_configuration.quota_configurations:
        if quota_configuration.quota_type == system_configuration.current_quota_type:
            quota_unit = quota_configuration.quota_unit
            if quota_configuration.quota_limit == -1:
                return None
            break

    if quota_unit is None:
        return None

    try:
        if quota_unit == QuotaUnit.TOKENS:
            tokens = message.message_tokens + message.answer_tokens
            return tokens
        if quota_unit == QuotaUnit.CREDITS:
            tokens = dify_config.get_model_credits(model_name)
            return tokens
        elif quota_unit == QuotaUnit.TIMES:
            return 1
        return None
    except Exception:
        logger.exception("Failed to calculate quota usage")
        return None


def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
    """Execute all Provider updates in a single transaction."""
    if not updates_to_perform:
        return

    updates_to_perform = sorted(updates_to_perform, key=lambda i: (i.filters.tenant_id, i.filters.provider_name))

    # Use SQLAlchemy's context manager for transaction management
    # This automatically handles commit/rollback
    with Session(db.engine) as session, session.begin():
        # Use a single transaction for all updates
        for update_operation in updates_to_perform:
            filters = update_operation.filters
            values = update_operation.values
            additional_filters = update_operation.additional_filters
            description = update_operation.description

            # Build the where conditions
            where_conditions = [
                Provider.tenant_id == filters.tenant_id,
                Provider.provider_name == filters.provider_name,
            ]

            # Add additional filters if specified
            if filters.provider_type is not None:
                where_conditions.append(Provider.provider_type == filters.provider_type)
            if filters.quota_type is not None:
                where_conditions.append(Provider.quota_type == filters.quota_type)
            if additional_filters.quota_limit_check:
                where_conditions.append(Provider.quota_limit > Provider.quota_used)

            # Prepare values dict for SQLAlchemy update
            update_values = {}

            # NOTE: For frequently used providers under high load, this implementation may experience
            # race conditions or update contention despite the time-window optimization:
            # 1. Multiple concurrent requests might check the same cache key simultaneously
            # 2. Redis cache operations are not atomic with database updates
            # 3. Heavy providers could still face database lock contention during peak usage
            # The current implementation is acceptable for most scenarios, but future optimization
            # considerations could include: batched updates, or async processing.
            if values.last_used is not None:
                cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name)
                now = datetime_utils.naive_utc_now()
                last_update = _get_last_update_timestamp(cache_key)
                if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
                    update_values["last_used"] = values.last_used
                    _set_last_update_timestamp(cache_key, now)

            if values.quota_used is not None:
                update_values["quota_used"] = values.quota_used

            # Skip the current update operation if no updates are required.
            if not update_values:
                continue

            # Build and execute the update statement
            stmt = update(Provider).where(*where_conditions).values(**update_values)
            result = session.execute(stmt)
            rows_affected = result.rowcount

            logger.debug(
                "Provider update (%s): %s rows affected. Filters: %s, Values: %s",
                description,
                rows_affected,
                filters.model_dump(),
                update_values,
            )

            # If no rows were affected for quota updates, log a warning
            if rows_affected == 0 and description == "quota_deduction_update":
                logger.warning(
                    "No Provider rows updated for quota deduction. "
                    "This may indicate quota limit exceeded or provider not found. "
                    "Filters: %s",
                    filters.model_dump(),
                )

    logger.debug("Successfully processed %s Provider updates", len(updates_to_perform))
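
For context: handle is registered on the message_was_created signal via the @message_was_created.connect decorator, so it runs whenever the message-creation flow dispatches that event with the new Message as the sender and the generate entity in the keyword arguments. Below is a minimal sketch of how a caller would trigger the handler, assuming message_was_created is a blinker-style signal (as the connect/send pattern suggests); notify_message_created and the way the entity is obtained are hypothetical simplifications, not the actual call site.

from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from events.message_event import message_was_created
from models.model import Message


def notify_message_created(message: Message, entity: ChatAppGenerateEntity) -> None:
    # Hypothetical dispatch helper: handle() above receives `message` as the
    # signal sender and reads "application_generate_entity" from kwargs.
    # Messages whose entity is not a Chat/AgentChat generate entity are ignored.
    message_was_created.send(message, application_generate_entity=entity)

A note on the quota update itself: passing Provider.quota_used + used_quota as the new quota_used value pushes the increment into the UPDATE statement on the database side, and the extra Provider.quota_limit > Provider.quota_used condition turns the deduction into a guarded, atomic operation, so concurrent messages cannot read a stale quota value and overwrite each other's deductions.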