Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: takatost <takatost@users.noreply.github.com>tags/0.3.33
| @@ -124,5 +124,11 @@ HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000 | |||
| HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 | |||
| HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 | |||
| # Stripe configuration | |||
| STRIPE_API_KEY= | |||
| STRIPE_WEBHOOK_SECRET= | |||
| STRIPE_WEBHOOK_SECRET= | |||
| # Billing configuration | |||
| BILLING_API_URL=http://127.0.0.1:8000/v1 | |||
| BILLING_API_SECRET_KEY= | |||
| STRIPE_WEBHOOK_BILLING_SECRET= | |||
| @@ -28,3 +28,5 @@ from .universal_chat import chat, conversation, message, parameter, audio | |||
| # Import webhook controllers | |||
| from .webhook import stripe | |||
| from .billing import billing | |||
| @@ -0,0 +1,85 @@ | |||
| import stripe | |||
| import os | |||
| from flask_restful import Resource, reqparse | |||
| from flask_login import current_user | |||
| from flask import current_app, request | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from controllers.console.wraps import only_edition_cloud | |||
| from libs.login import login_required | |||
| from services.billing_service import BillingService | |||
| class BillingInfo(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| edition = current_app.config['EDITION'] | |||
| if edition != 'CLOUD': | |||
| return {"enabled": False} | |||
| return BillingService.get_info(current_user.current_tenant_id) | |||
| class Subscription(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @only_edition_cloud | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) | |||
| parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) | |||
| args = parser.parse_args() | |||
| return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id) | |||
| class Invoices(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @only_edition_cloud | |||
| def get(self): | |||
| return BillingService.get_invoices(current_user.email) | |||
| class StripeBillingWebhook(Resource): | |||
| @setup_required | |||
| @only_edition_cloud | |||
| def post(self): | |||
| payload = request.data | |||
| sig_header = request.headers.get('STRIPE_SIGNATURE') | |||
| webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET') | |||
| try: | |||
| event = stripe.Webhook.construct_event( | |||
| payload, sig_header, webhook_secret | |||
| ) | |||
| except ValueError as e: | |||
| # Invalid payload | |||
| return 'Invalid payload', 400 | |||
| except stripe.error.SignatureVerificationError as e: | |||
| # Invalid signature | |||
| return 'Invalid signature', 400 | |||
| BillingService.process_event(event) | |||
| return 'success', 200 | |||
| api.add_resource(BillingInfo, '/billing/info') | |||
| api.add_resource(Subscription, '/billing/subscription') | |||
| api.add_resource(Invoices, '/billing/invoices') | |||
| api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe') | |||
| @@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>') | |||
| api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') | |||
| api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') | |||
| api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') | |||
| @@ -0,0 +1,55 @@ | |||
| import os | |||
| import requests | |||
| from services.dataset_service import DatasetService | |||
| class BillingService: | |||
| base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') | |||
| secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') | |||
| @classmethod | |||
| def get_info(cls, tenant_id: str): | |||
| params = {'tenant_id': tenant_id} | |||
| billing_info = cls._send_request('GET', '/info', params=params) | |||
| vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024 | |||
| billing_info['vector_space']['size'] = int(vector_size) | |||
| return billing_info | |||
| @classmethod | |||
| def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''): | |||
| params = { | |||
| 'plan': plan, | |||
| 'interval': interval, | |||
| 'prefilled_email': prefilled_email, | |||
| 'user_name': user_name, | |||
| 'tenant_id': tenant_id | |||
| } | |||
| return cls._send_request('GET', '/subscription', params=params) | |||
| @classmethod | |||
| def get_invoices(cls, prefilled_email: str = ''): | |||
| params = {'prefilled_email': prefilled_email} | |||
| return cls._send_request('GET', '/invoices', params=params) | |||
| @classmethod | |||
| def _send_request(cls, method, endpoint, json=None, params=None): | |||
| headers = { | |||
| "Content-Type": "application/json", | |||
| "Billing-Api-Secret-Key": cls.secret_key | |||
| } | |||
| url = f"{cls.base_url}{endpoint}" | |||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||
| return response.json() | |||
| @classmethod | |||
| def process_event(cls, event: dict): | |||
| json = { | |||
| "content": event, | |||
| } | |||
| return cls._send_request('POST', '/webhook/stripe', json=json) | |||
| @@ -227,6 +227,36 @@ class DatasetService: | |||
| return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ | |||
| .order_by(db.desc(AppDatasetJoin.created_at)).all() | |||
| @staticmethod | |||
| def get_tenant_datasets_usage(tenant_id): | |||
| # get the high_quality datasets | |||
| dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality', | |||
| Dataset.tenant_id == tenant_id).all() | |||
| if not dataset_ids: | |||
| return 0 | |||
| dataset_ids = [result[0] for result in dataset_ids] | |||
| document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids), | |||
| Document.tenant_id == tenant_id, | |||
| Document.completed_at.isnot(None), | |||
| Document.enabled == True, | |||
| Document.archived == False | |||
| ).all() | |||
| if not document_ids: | |||
| return 0 | |||
| document_ids = [result[0] for result in document_ids] | |||
| document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids), | |||
| DocumentSegment.tenant_id == tenant_id, | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.enabled == True, | |||
| ).all() | |||
| if not document_segments: | |||
| return 0 | |||
| total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments) | |||
| total_vector_size = 1536 * 4 * len(document_segments) | |||
| return total_words_size + total_vector_size | |||
| class DocumentService: | |||
| DEFAULT_RULES = { | |||
| @@ -488,7 +518,8 @@ class DocumentService: | |||
| 'score_threshold_enabled': False | |||
| } | |||
| dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model | |||
| dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( | |||
| 'retrieval_model') else default_retrieval_model | |||
| documents = [] | |||
| batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | |||