ソースを参照

Feat/dify billing (#1679)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: takatost <takatost@users.noreply.github.com>
tags/0.3.33
Garfield Dai 1年前
コミット
053102f433
コミッターのメールアドレスに関連付けられたアカウントが存在しません

+ 7
- 1
api/.env.example ファイルの表示

@@ -124,5 +124,11 @@ HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100

# Stripe configuration
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
STRIPE_WEBHOOK_SECRET=

# Billing configuration
BILLING_API_URL=http://127.0.0.1:8000/v1
BILLING_API_SECRET_KEY=
STRIPE_WEBHOOK_BILLING_SECRET=

+ 2
- 0
api/controllers/console/__init__.py ファイルの表示

@@ -28,3 +28,5 @@ from .universal_chat import chat, conversation, message, parameter, audio

# Import webhook controllers
from .webhook import stripe

from .billing import billing

+ 0
- 0
api/controllers/console/billing/__init__.py ファイルの表示


+ 85
- 0
api/controllers/console/billing/billing.py ファイルの表示

@@ -0,0 +1,85 @@
import stripe
import os

from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app, request

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import only_edition_cloud
from libs.login import login_required
from services.billing_service import BillingService


class BillingInfo(Resource):

@setup_required
@login_required
@account_initialization_required
def get(self):

edition = current_app.config['EDITION']
if edition != 'CLOUD':
return {"enabled": False}

return BillingService.get_info(current_user.current_tenant_id)


class Subscription(Resource):

@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):

parser = reqparse.RequestParser()
parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()

return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)


class Invoices(Resource):

@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):

return BillingService.get_invoices(current_user.email)


class StripeBillingWebhook(Resource):

@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')

try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400

BillingService.process_event(event)

return 'success', 200


api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')

+ 1
- 0
api/controllers/console/datasets/datasets.py ファイルの表示

@@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')


+ 55
- 0
api/services/billing_service.py ファイルの表示

@@ -0,0 +1,55 @@
import os
import requests

from services.dataset_service import DatasetService


class BillingService:
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')

@classmethod
def get_info(cls, tenant_id: str):
params = {'tenant_id': tenant_id}

billing_info = cls._send_request('GET', '/info', params=params)

vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024
billing_info['vector_space']['size'] = int(vector_size)

return billing_info

@classmethod
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''):
params = {
'plan': plan,
'interval': interval,
'prefilled_email': prefilled_email,
'user_name': user_name,
'tenant_id': tenant_id
}
return cls._send_request('GET', '/subscription', params=params)

@classmethod
def get_invoices(cls, prefilled_email: str = ''):
params = {'prefilled_email': prefilled_email}
return cls._send_request('GET', '/invoices', params=params)

@classmethod
def _send_request(cls, method, endpoint, json=None, params=None):
headers = {
"Content-Type": "application/json",
"Billing-Api-Secret-Key": cls.secret_key
}

url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)

return response.json()

@classmethod
def process_event(cls, event: dict):
json = {
"content": event,
}
return cls._send_request('POST', '/webhook/stripe', json=json)

+ 32
- 1
api/services/dataset_service.py ファイルの表示

@@ -227,6 +227,36 @@ class DatasetService:
return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
.order_by(db.desc(AppDatasetJoin.created_at)).all()

@staticmethod
def get_tenant_datasets_usage(tenant_id):
# get the high_quality datasets
dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality',
Dataset.tenant_id == tenant_id).all()
if not dataset_ids:
return 0
dataset_ids = [result[0] for result in dataset_ids]
document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids),
Document.tenant_id == tenant_id,
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False
).all()
if not document_ids:
return 0
document_ids = [result[0] for result in document_ids]
document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids),
DocumentSegment.tenant_id == tenant_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.enabled == True,
).all()
if not document_segments:
return 0

total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments)
total_vector_size = 1536 * 4 * len(document_segments)

return total_words_size + total_vector_size


class DocumentService:
DEFAULT_RULES = {
@@ -488,7 +518,8 @@ class DocumentService:
'score_threshold_enabled': False
}

dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
'retrieval_model') else default_retrieval_model

documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))

読み込み中…
キャンセル
保存