| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- import logging
- import time as time_module
- from datetime import datetime
- from typing import Any, Optional
-
- from pydantic import BaseModel
- from sqlalchemy import update
- from sqlalchemy.orm import Session
-
- from configs import dify_config
- from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
- from core.entities.provider_entities import QuotaUnit, SystemConfiguration
- from core.plugin.entities.plugin import ModelProviderID
- from events.message_event import message_was_created
- from extensions.ext_database import db
- from extensions.ext_redis import redis_client, redis_fallback
- from libs import datetime_utils
- from models.model import Message
- from models.provider import Provider, ProviderType
-
logger = logging.getLogger(__name__)

# Key namespace under which the time of a provider's most recent "last_used" write is cached in Redis.
_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used"
# Lifetime of a cached timestamp entry: 10 minutes, expressed in seconds.
_CACHE_TTL_SECONDS = 10 * 60
# A cached timestamp must be older than this (5 minutes, in seconds) before last_used is written again.
LAST_USED_UPDATE_WINDOW_SECONDS = 5 * 60
-
-
def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:
    """Build the Redis key holding the last-used timestamp of one tenant's provider.

    The key has the form ``<prefix>:<tenant_id>:<provider_name>``.
    """
    return "{}:{}:{}".format(_PROVIDER_LAST_USED_CACHE_PREFIX, tenant_id, provider_name)
-
-
# NOTE(review): redis_fallback presumably returns `default_return` when Redis fails — confirm in extensions.ext_redis.
@redis_fallback(default_return=None)
def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]:
    """Read the cached last-update time stored under ``cache_key``.

    Returns:
        The stored epoch value converted with ``datetime.fromtimestamp``,
        or ``None`` when the key holds no (or an empty) value.
    """
    raw_value = redis_client.get(cache_key)
    if not raw_value:
        return None
    # The value is stored as the text of a float epoch (see _set_last_update_timestamp).
    epoch_seconds = float(raw_value.decode("utf-8"))
    return datetime.fromtimestamp(epoch_seconds)
-
-
@redis_fallback()
def _set_last_update_timestamp(cache_key: str, timestamp: datetime) -> None:
    """Store ``timestamp`` under ``cache_key`` as epoch-seconds text, expiring after the cache TTL."""
    epoch_text = str(timestamp.timestamp())
    redis_client.setex(cache_key, _CACHE_TTL_SECONDS, epoch_text)
-
-
class _ProviderUpdateFilters(BaseModel):
    """Equality criteria selecting which Provider rows an update touches.

    ``tenant_id`` and ``provider_name`` are always required; the two optional
    fields narrow the match further only when they are set.
    """

    # Tenant that owns the provider rows.
    tenant_id: str
    # Provider name to match.
    provider_name: str
    # Optional provider type to match; None means "do not filter on type".
    provider_type: str | None = None
    # Optional quota type to match; None means "do not filter on quota type".
    quota_type: str | None = None
-
-
class _ProviderUpdateAdditionalFilters(BaseModel):
    """Non-equality conditions that can be layered on top of the basic filters."""

    # When True, the update is restricted to rows whose quota limit is still above the used quota.
    quota_limit_check: bool = False
-
-
class _ProviderUpdateValues(BaseModel):
    """New column values for the matched Provider rows; a None field is left unchanged."""

    # Timestamp to write into the last_used column.
    last_used: datetime | None = None
    # New quota_used value; may be a SQL expression such as ``Provider.quota_used + <int>``.
    quota_used: Any | None = None
-
-
class _ProviderUpdateOperation(BaseModel):
    """One UPDATE against the Provider table: which rows, which values, and a label for logs."""

    # Row-selection criteria.
    filters: _ProviderUpdateFilters
    # Column values to write.
    values: _ProviderUpdateValues
    # Extra conditions; defaults to "no additional conditions".
    additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
    # Short label identifying the operation in log output.
    description: str = "unknown"
-
-
@message_was_created.connect
def handle(sender: Message, **kwargs):
    """
    Apply all Provider-table changes caused by a newly created message.

    Connected to the ``message_was_created`` signal. It supersedes the two
    former handlers:
    - update_provider_last_used_at_when_message_created
    - deduct_quota_when_message_created

    Both the ``last_used`` refresh and the optional quota deduction are handed
    to ``_execute_provider_updates`` together so they run in one transaction.

    Args:
        sender: The message that was created.
        **kwargs: Signal payload; only ``application_generate_entity`` is read.
            Nothing happens unless it is a chat or agent-chat generate entity.

    Raises:
        Exception: Any error raised while executing the updates is logged and re-raised.
    """
    generate_entity = kwargs.get("application_generate_entity")
    if not isinstance(generate_entity, (ChatAppGenerateEntity, AgentChatAppGenerateEntity)):
        return

    tenant_id = generate_entity.app_config.tenant_id
    provider_name = generate_entity.model_conf.provider
    current_time = datetime_utils.naive_utc_now()

    # Step 1: the last_used refresh is always queued.
    pending_updates: list[_ProviderUpdateOperation] = [
        _ProviderUpdateOperation(
            filters=_ProviderUpdateFilters(tenant_id=tenant_id, provider_name=provider_name),
            values=_ProviderUpdateValues(last_used=current_time),
            description="basic_last_used_update",
        )
    ]
    logger.info("provider used, tenant_id=%s, provider_name=%s", tenant_id, provider_name)

    # Step 2: a quota deduction is queued only for system providers with an active quota type.
    model_conf = generate_entity.model_conf
    configuration = model_conf.provider_model_bundle.configuration
    uses_system_quota = (
        configuration.using_provider_type == ProviderType.SYSTEM
        and configuration.system_configuration
        and configuration.system_configuration.current_quota_type is not None
    )

    if uses_system_quota:
        system_configuration = configuration.system_configuration
        used_quota = _calculate_quota_usage(
            message=sender,
            system_configuration=system_configuration,
            model_name=model_conf.model,
        )

        # None means "nothing to deduct" (unlimited quota, unknown unit, or a calculation error).
        if used_quota is not None:
            pending_updates.append(
                _ProviderUpdateOperation(
                    filters=_ProviderUpdateFilters(
                        tenant_id=tenant_id,
                        provider_name=ModelProviderID(model_conf.provider).provider_name,
                        provider_type=ProviderType.SYSTEM.value,
                        quota_type=system_configuration.current_quota_type.value,
                    ),
                    values=_ProviderUpdateValues(
                        quota_used=Provider.quota_used + used_quota,
                        last_used=current_time,
                    ),
                    # Only rows with Provider.quota_limit > Provider.quota_used are charged.
                    additional_filters=_ProviderUpdateAdditionalFilters(quota_limit_check=True),
                    description="quota_deduction_update",
                )
            )

    # Step 3: run everything and report how long it took, whether it succeeded or failed.
    started_at = time_module.perf_counter()
    try:
        _execute_provider_updates(pending_updates)

        elapsed = time_module.perf_counter() - started_at
        logger.info(
            "Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s",
            len(pending_updates),
            elapsed,
            tenant_id,
            provider_name,
        )
    except Exception:
        elapsed = time_module.perf_counter() - started_at
        logger.exception(
            "Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s",
            elapsed,
            len(pending_updates),
            tenant_id,
            provider_name,
        )
        raise
-
-
def _calculate_quota_usage(
    *, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> Optional[int]:
    """Work out how much quota the given message consumes.

    Args:
        message: Message whose ``message_tokens`` and ``answer_tokens`` are summed for token quotas.
        system_configuration: Supplies the current quota type and the list of quota configurations.
        model_name: Model name passed to ``dify_config.get_model_credits`` for credit quotas.

    Returns:
        The amount to deduct, or ``None`` when no configuration matches the
        current quota type, the matching quota limit is ``-1``, the unit is not
        recognised, or the calculation raised an exception.
    """
    # Only the first configuration matching the current quota type is considered.
    active_configuration = next(
        (
            candidate
            for candidate in system_configuration.quota_configurations
            if candidate.quota_type == system_configuration.current_quota_type
        ),
        None,
    )
    if active_configuration is None:
        return None

    unit = active_configuration.quota_unit
    # A limit of -1 yields no deduction.
    if active_configuration.quota_limit == -1 or unit is None:
        return None

    try:
        if unit == QuotaUnit.TOKENS:
            return message.message_tokens + message.answer_tokens
        if unit == QuotaUnit.CREDITS:
            return dify_config.get_model_credits(model_name)
        if unit == QuotaUnit.TIMES:
            return 1
        return None
    except Exception:
        # Best effort: a failed calculation is logged and results in no deduction.
        logger.exception("Failed to calculate quota usage")
        return None
-
-
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
    """Run the given Provider update operations inside one database transaction.

    Operations are processed ordered by ``(tenant_id, provider_name)``. For each
    one an ``UPDATE`` is built from its filters and values. Writing ``last_used``
    is throttled through a Redis-cached timestamp: the column is only included
    when no cached timestamp exists or it is older than
    ``LAST_USED_UPDATE_WINDOW_SECONDS``. An operation left with nothing to write
    is skipped.

    Args:
        updates_to_perform: Operations to apply; an empty list is a no-op.
    """
    if not updates_to_perform:
        return

    # NOTE(review): the fixed ordering presumably keeps row-lock order consistent across
    # concurrent transactions — confirm.
    ordered_operations = sorted(
        updates_to_perform,
        key=lambda operation: (operation.filters.tenant_id, operation.filters.provider_name),
    )

    # Session.begin() as a context manager commits on success and rolls back on error.
    with Session(db.engine) as session, session.begin():
        for operation in ordered_operations:
            filters = operation.filters
            requested = operation.values

            # Mandatory equality conditions, then the optional ones that are set.
            conditions = [
                Provider.tenant_id == filters.tenant_id,
                Provider.provider_name == filters.provider_name,
            ]
            if filters.provider_type is not None:
                conditions.append(Provider.provider_type == filters.provider_type)
            if filters.quota_type is not None:
                conditions.append(Provider.quota_type == filters.quota_type)
            if operation.additional_filters.quota_limit_check:
                conditions.append(Provider.quota_limit > Provider.quota_used)

            column_changes: dict[str, Any] = {}

            # NOTE: Under heavy load on a frequently used provider this throttle can still
            # race or contend:
            # 1. Concurrent requests may read the same cache key at the same moment.
            # 2. The Redis write is not atomic with the database update.
            # 3. Busy providers may still hit database lock contention at peak times.
            # This is acceptable for most scenarios; batched updates or async processing
            # are possible future optimizations.
            if requested.last_used is not None:
                cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name)
                checked_at = datetime_utils.naive_utc_now()
                previous_write = _get_last_update_timestamp(cache_key)

                window_elapsed = (
                    previous_write is None
                    or (checked_at - previous_write).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS
                )
                if window_elapsed:
                    column_changes["last_used"] = requested.last_used
                    _set_last_update_timestamp(cache_key, checked_at)

            if requested.quota_used is not None:
                column_changes["quota_used"] = requested.quota_used

            # Nothing left to write for this operation (last_used was throttled away).
            if not column_changes:
                continue

            statement = update(Provider).where(*conditions).values(**column_changes)
            affected_rows = session.execute(statement).rowcount

            logger.debug(
                "Provider update (%s): %s rows affected. Filters: %s, Values: %s",
                operation.description,
                affected_rows,
                filters.model_dump(),
                column_changes,
            )

            # A quota deduction that matched no row is worth flagging.
            if affected_rows == 0 and operation.description == "quota_deduction_update":
                logger.warning(
                    "No Provider rows updated for quota deduction. "
                    "This may indicate quota limit exceeded or provider not found. "
                    "Filters: %s",
                    filters.model_dump(),
                )

        logger.debug("Successfully processed %s Provider updates", len(ordered_operations))
|