| @@ -10,7 +10,12 @@ from controllers.console import api | |||
| from controllers.console.apikey import api_key_fields, api_key_list | |||
| from controllers.console.app.error import ProviderNotInitializeError | |||
| from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError | |||
| from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_rate_limit_check, | |||
| enterprise_license_required, | |||
| setup_required, | |||
| ) | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| @@ -93,6 +98,7 @@ class DatasetListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| @@ -207,6 +213,7 @@ class DatasetApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| @@ -310,6 +317,7 @@ class DatasetApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def delete(self, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| @@ -27,6 +27,7 @@ from controllers.console.datasets.error import ( | |||
| ) | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_rate_limit_check, | |||
| cloud_edition_billing_resource_check, | |||
| setup_required, | |||
| ) | |||
| @@ -230,6 +231,7 @@ class DatasetDocumentListApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(documents_and_batch_fields) | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| @@ -285,6 +287,7 @@ class DatasetDocumentListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def delete(self, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| @@ -308,6 +311,7 @@ class DatasetInitApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(dataset_and_document_fields) | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| @@ -680,6 +684,7 @@ class DocumentProcessingApi(DocumentResource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, document_id, action): | |||
| dataset_id = str(dataset_id) | |||
| document_id = str(document_id) | |||
| @@ -716,6 +721,7 @@ class DocumentDeleteApi(DocumentResource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def delete(self, dataset_id, document_id): | |||
| dataset_id = str(dataset_id) | |||
| document_id = str(document_id) | |||
| @@ -784,6 +790,7 @@ class DocumentStatusApi(DocumentResource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, action): | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| @@ -879,6 +886,7 @@ class DocumentPauseApi(DocumentResource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, document_id): | |||
| """pause document.""" | |||
| dataset_id = str(dataset_id) | |||
| @@ -911,6 +919,7 @@ class DocumentRecoverApi(DocumentResource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, document_id): | |||
| """recover document.""" | |||
| dataset_id = str(dataset_id) | |||
| @@ -940,6 +949,7 @@ class DocumentRetryApi(DocumentResource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self, dataset_id): | |||
| """retry document.""" | |||
| @@ -19,6 +19,7 @@ from controllers.console.datasets.error import ( | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_knowledge_limit_check, | |||
| cloud_edition_billing_rate_limit_check, | |||
| cloud_edition_billing_resource_check, | |||
| setup_required, | |||
| ) | |||
| @@ -106,6 +107,7 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def delete(self, dataset_id, document_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, document_id, action): | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| @@ -192,6 +195,7 @@ class DatasetDocumentSegmentAddApi(Resource): | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self, dataset_id, document_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -242,6 +246,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -302,6 +307,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def delete(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -339,6 +345,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self, dataset_id, document_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -405,6 +412,7 @@ class ChildChunkAddApi(Resource): | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -503,6 +511,7 @@ class ChildChunkAddApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -546,6 +555,7 @@ class ChildChunkUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def delete(self, dataset_id, document_id, segment_id, child_chunk_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -590,6 +600,7 @@ class ChildChunkUpdateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def patch(self, dataset_id, document_id, segment_id, child_chunk_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| @@ -2,7 +2,11 @@ from flask_restful import Resource # type: ignore | |||
| from controllers.console import api | |||
| from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_rate_limit_check, | |||
| setup_required, | |||
| ) | |||
| from libs.login import login_required | |||
| @@ -10,6 +14,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_rate_limit_check("knowledge") | |||
| def post(self, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| @@ -1,5 +1,6 @@ | |||
| import json | |||
| import os | |||
| import time | |||
| from functools import wraps | |||
| from flask import abort, request | |||
| @@ -7,6 +8,7 @@ from flask_login import current_user # type: ignore | |||
| from configs import dify_config | |||
| from controllers.console.workspace.error import AccountNotInitializedError | |||
| from extensions.ext_redis import redis_client | |||
| from models.model import DifySetup | |||
| from services.feature_service import FeatureService, LicenseStatus | |||
| from services.operation_service import OperationService | |||
| @@ -66,7 +68,9 @@ def cloud_edition_billing_resource_check(resource: str): | |||
| elif resource == "apps" and 0 < apps.limit <= apps.size: | |||
| abort(403, "The number of apps has reached the limit of your subscription.") | |||
| elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: | |||
| abort(403, "The capacity of the vector space has reached the limit of your subscription.") | |||
| abort( | |||
| 403, "The capacity of the knowledge storage space has reached the limit of your subscription." | |||
| ) | |||
| elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: | |||
| # The api of file upload is used in the multiple places, | |||
| # so we need to check the source of the request from datasets | |||
| @@ -111,6 +115,33 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): | |||
| return interceptor | |||
| def cloud_edition_billing_rate_limit_check(resource: str): | |||
| def interceptor(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if resource == "knowledge": | |||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) | |||
| if knowledge_rate_limit.enabled: | |||
| current_time = int(time.time() * 1000) | |||
| key = f"rate_limit_{current_user.current_tenant_id}" | |||
| redis_client.zadd(key, {current_time: current_time}) | |||
| redis_client.zremrangebyscore(key, 0, current_time - 60000) | |||
| request_count = redis_client.zcard(key) | |||
| if request_count > knowledge_rate_limit.limit: | |||
| abort( | |||
| 403, "Sorry, you have reached the knowledge base request rate limit of your subscription." | |||
| ) | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| return interceptor | |||
| def cloud_utm_record(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| @@ -1,3 +1,4 @@ | |||
| import time | |||
| from collections.abc import Callable | |||
| from datetime import UTC, datetime, timedelta | |||
| from enum import Enum | |||
| @@ -13,6 +14,7 @@ from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden, Unauthorized | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs.login import _get_user | |||
| from models.account import Account, Tenant, TenantAccountJoin, TenantStatus | |||
| from models.model import ApiToken, App, EndUser | |||
| @@ -139,6 +141,35 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s | |||
| return interceptor | |||
| def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| if resource == "knowledge": | |||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id) | |||
| if knowledge_rate_limit.enabled: | |||
| current_time = int(time.time() * 1000) | |||
| key = f"rate_limit_{api_token.tenant_id}" | |||
| redis_client.zadd(key, {current_time: current_time}) | |||
| redis_client.zremrangebyscore(key, 0, current_time - 60000) | |||
| request_count = redis_client.zcard(key) | |||
| if request_count > knowledge_rate_limit.limit: | |||
| raise Forbidden( | |||
| "Sorry, you have reached the knowledge base request rate limit of your subscription." | |||
| ) | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| return interceptor | |||
| def validate_dataset_token(view=None): | |||
| def decorator(view): | |||
| @wraps(view) | |||
| @@ -1,4 +1,5 @@ | |||
| import logging | |||
| import time | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| @@ -19,8 +20,10 @@ from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, Document | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from services.feature_service import FeatureService | |||
| from .entities import KnowledgeRetrievalNodeData | |||
| from .exc import ( | |||
| @@ -61,6 +64,23 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." | |||
| ) | |||
| # check rate limit | |||
| if self.tenant_id: | |||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) | |||
| if knowledge_rate_limit.enabled: | |||
| current_time = int(time.time() * 1000) | |||
| key = f"rate_limit_{self.tenant_id}" | |||
| redis_client.zadd(key, {current_time: current_time}) | |||
| redis_client.zremrangebyscore(key, 0, current_time - 60000) | |||
| request_count = redis_client.zcard(key) | |||
| if request_count > knowledge_rate_limit.limit: | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| inputs=variables, | |||
| error="Sorry, you have reached the knowledge base request rate limit of your subscription.", | |||
| error_type="RateLimitExceeded", | |||
| ) | |||
| # retrieve knowledge | |||
| try: | |||
| results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) | |||
| @@ -19,6 +19,14 @@ class BillingService: | |||
| billing_info = cls._send_request("GET", "/subscription/info", params=params) | |||
| return billing_info | |||
| @classmethod | |||
| def get_knowledge_rate_limit(cls, tenant_id: str): | |||
| params = {"tenant_id": tenant_id} | |||
| knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params) | |||
| return knowledge_rate_limit.get("limit", 10) | |||
| @classmethod | |||
| def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): | |||
| params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} | |||
| @@ -41,6 +41,7 @@ class FeatureModel(BaseModel): | |||
| members: LimitationModel = LimitationModel(size=0, limit=1) | |||
| apps: LimitationModel = LimitationModel(size=0, limit=10) | |||
| vector_space: LimitationModel = LimitationModel(size=0, limit=5) | |||
| knowledge_rate_limit: int = 10 | |||
| annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) | |||
| documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) | |||
| docs_processing: str = "standard" | |||
| @@ -52,6 +53,11 @@ class FeatureModel(BaseModel): | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| class KnowledgeRateLimitModel(BaseModel): | |||
| enabled: bool = False | |||
| limit: int = 10 | |||
| class SystemFeatureModel(BaseModel): | |||
| sso_enforced_for_signin: bool = False | |||
| sso_enforced_for_signin_protocol: str = "" | |||
| @@ -79,6 +85,14 @@ class FeatureService: | |||
| return features | |||
| @classmethod | |||
| def get_knowledge_rate_limit(cls, tenant_id: str): | |||
| knowledge_rate_limit = KnowledgeRateLimitModel() | |||
| if dify_config.BILLING_ENABLED and tenant_id: | |||
| knowledge_rate_limit.enabled = True | |||
| knowledge_rate_limit.limit = BillingService.get_knowledge_rate_limit(tenant_id) | |||
| return knowledge_rate_limit | |||
| @classmethod | |||
| def get_system_features(cls) -> SystemFeatureModel: | |||
| system_features = SystemFeatureModel() | |||
| @@ -144,6 +158,9 @@ class FeatureService: | |||
| if "model_load_balancing_enabled" in billing_info: | |||
| features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] | |||
| if "knowledge_rate_limit" in billing_info: | |||
| features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] | |||
| @classmethod | |||
| def _fulfill_params_from_enterprise(cls, features): | |||
| enterprise_info = EnterpriseService.get_info() | |||