Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: takatost <takatost@users.noreply.github.com>tags/0.3.33
| HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 | HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 | ||||
| HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 | HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 | ||||
| # Stripe configuration | |||||
| STRIPE_API_KEY= | STRIPE_API_KEY= | ||||
| STRIPE_WEBHOOK_SECRET= | |||||
| STRIPE_WEBHOOK_SECRET= | |||||
| # Billing configuration | |||||
| BILLING_API_URL=http://127.0.0.1:8000/v1 | |||||
| BILLING_API_SECRET_KEY= | |||||
| STRIPE_WEBHOOK_BILLING_SECRET= |
| # Import webhook controllers | # Import webhook controllers | ||||
| from .webhook import stripe | from .webhook import stripe | ||||
| from .billing import billing |
| import stripe | |||||
| import os | |||||
| from flask_restful import Resource, reqparse | |||||
| from flask_login import current_user | |||||
| from flask import current_app, request | |||||
| from controllers.console import api | |||||
| from controllers.console.setup import setup_required | |||||
| from controllers.console.wraps import account_initialization_required | |||||
| from controllers.console.wraps import only_edition_cloud | |||||
| from libs.login import login_required | |||||
| from services.billing_service import BillingService | |||||
| class BillingInfo(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| edition = current_app.config['EDITION'] | |||||
| if edition != 'CLOUD': | |||||
| return {"enabled": False} | |||||
| return BillingService.get_info(current_user.current_tenant_id) | |||||
| class Subscription(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @only_edition_cloud | |||||
| def get(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) | |||||
| parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) | |||||
| args = parser.parse_args() | |||||
| return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id) | |||||
| class Invoices(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @only_edition_cloud | |||||
| def get(self): | |||||
| return BillingService.get_invoices(current_user.email) | |||||
| class StripeBillingWebhook(Resource): | |||||
| @setup_required | |||||
| @only_edition_cloud | |||||
| def post(self): | |||||
| payload = request.data | |||||
| sig_header = request.headers.get('STRIPE_SIGNATURE') | |||||
| webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET') | |||||
| try: | |||||
| event = stripe.Webhook.construct_event( | |||||
| payload, sig_header, webhook_secret | |||||
| ) | |||||
| except ValueError as e: | |||||
| # Invalid payload | |||||
| return 'Invalid payload', 400 | |||||
| except stripe.error.SignatureVerificationError as e: | |||||
| # Invalid signature | |||||
| return 'Invalid signature', 400 | |||||
| BillingService.process_event(event) | |||||
| return 'success', 200 | |||||
| api.add_resource(BillingInfo, '/billing/info') | |||||
| api.add_resource(Subscription, '/billing/subscription') | |||||
| api.add_resource(Invoices, '/billing/invoices') | |||||
| api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe') |
| api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') | api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') | ||||
| api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') | api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') | ||||
| api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') | api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') | ||||
| import os | |||||
| import requests | |||||
| from services.dataset_service import DatasetService | |||||
| class BillingService: | |||||
| base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') | |||||
| secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') | |||||
| @classmethod | |||||
| def get_info(cls, tenant_id: str): | |||||
| params = {'tenant_id': tenant_id} | |||||
| billing_info = cls._send_request('GET', '/info', params=params) | |||||
| vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024 | |||||
| billing_info['vector_space']['size'] = int(vector_size) | |||||
| return billing_info | |||||
| @classmethod | |||||
| def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''): | |||||
| params = { | |||||
| 'plan': plan, | |||||
| 'interval': interval, | |||||
| 'prefilled_email': prefilled_email, | |||||
| 'user_name': user_name, | |||||
| 'tenant_id': tenant_id | |||||
| } | |||||
| return cls._send_request('GET', '/subscription', params=params) | |||||
| @classmethod | |||||
| def get_invoices(cls, prefilled_email: str = ''): | |||||
| params = {'prefilled_email': prefilled_email} | |||||
| return cls._send_request('GET', '/invoices', params=params) | |||||
| @classmethod | |||||
| def _send_request(cls, method, endpoint, json=None, params=None): | |||||
| headers = { | |||||
| "Content-Type": "application/json", | |||||
| "Billing-Api-Secret-Key": cls.secret_key | |||||
| } | |||||
| url = f"{cls.base_url}{endpoint}" | |||||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||||
| return response.json() | |||||
| @classmethod | |||||
| def process_event(cls, event: dict): | |||||
| json = { | |||||
| "content": event, | |||||
| } | |||||
| return cls._send_request('POST', '/webhook/stripe', json=json) |
| return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ | return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ | ||||
| .order_by(db.desc(AppDatasetJoin.created_at)).all() | .order_by(db.desc(AppDatasetJoin.created_at)).all() | ||||
| @staticmethod | |||||
| def get_tenant_datasets_usage(tenant_id): | |||||
| # get the high_quality datasets | |||||
| dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality', | |||||
| Dataset.tenant_id == tenant_id).all() | |||||
| if not dataset_ids: | |||||
| return 0 | |||||
| dataset_ids = [result[0] for result in dataset_ids] | |||||
| document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids), | |||||
| Document.tenant_id == tenant_id, | |||||
| Document.completed_at.isnot(None), | |||||
| Document.enabled == True, | |||||
| Document.archived == False | |||||
| ).all() | |||||
| if not document_ids: | |||||
| return 0 | |||||
| document_ids = [result[0] for result in document_ids] | |||||
| document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids), | |||||
| DocumentSegment.tenant_id == tenant_id, | |||||
| DocumentSegment.completed_at.isnot(None), | |||||
| DocumentSegment.enabled == True, | |||||
| ).all() | |||||
| if not document_segments: | |||||
| return 0 | |||||
| total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments) | |||||
| total_vector_size = 1536 * 4 * len(document_segments) | |||||
| return total_words_size + total_vector_size | |||||
| class DocumentService: | class DocumentService: | ||||
| DEFAULT_RULES = { | DEFAULT_RULES = { | ||||
| 'score_threshold_enabled': False | 'score_threshold_enabled': False | ||||
| } | } | ||||
| dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model | |||||
| dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( | |||||
| 'retrieval_model') else default_retrieval_model | |||||
| documents = [] | documents = [] | ||||
| batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) |