| @@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE= | |||
| HOSTED_OPENAI_API_ORGANIZATION= | |||
| HOSTED_OPENAI_QUOTA_LIMIT=200 | |||
| HOSTED_OPENAI_PAID_ENABLED=false | |||
| HOSTED_OPENAI_PAID_STRIPE_PRICE_ID= | |||
| HOSTED_OPENAI_PAID_INCREASE_QUOTA=1 | |||
| HOSTED_AZURE_OPENAI_ENABLED=false | |||
| HOSTED_AZURE_OPENAI_API_KEY= | |||
| @@ -119,16 +117,7 @@ HOSTED_ANTHROPIC_API_BASE= | |||
| HOSTED_ANTHROPIC_API_KEY= | |||
| HOSTED_ANTHROPIC_QUOTA_LIMIT=600000 | |||
| HOSTED_ANTHROPIC_PAID_ENABLED=false | |||
| HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID= | |||
| HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000 | |||
| HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 | |||
| HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 | |||
| # Stripe configuration | |||
| STRIPE_API_KEY= | |||
| STRIPE_WEBHOOK_SECRET= | |||
| # Billing configuration | |||
| BILLING_API_URL=http://127.0.0.1:8000/v1 | |||
| BILLING_API_SECRET_KEY= | |||
| STRIPE_WEBHOOK_BILLING_SECRET= | |||
| BILLING_API_SECRET_KEY= | |||
| @@ -20,7 +20,7 @@ from flask_cors import CORS | |||
| from core.model_providers.providers import hosted | |||
| from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | |||
| ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension | |||
| ext_database, ext_storage, ext_mail, ext_code_based_extension | |||
| from extensions.ext_database import db | |||
| from extensions.ext_login import login_manager | |||
| @@ -96,7 +96,6 @@ def initialize_extensions(app): | |||
| ext_login.init_app(app) | |||
| ext_mail.init_app(app) | |||
| ext_sentry.init_app(app) | |||
| ext_stripe.init_app(app) | |||
| # Flask-Login configuration | |||
| @@ -1,11 +1,8 @@ | |||
| # -*- coding:utf-8 -*- | |||
| import os | |||
| from datetime import timedelta | |||
| import dotenv | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| dotenv.load_dotenv() | |||
| @@ -44,15 +41,11 @@ DEFAULTS = { | |||
| 'HOSTED_OPENAI_QUOTA_LIMIT': 200, | |||
| 'HOSTED_OPENAI_ENABLED': 'False', | |||
| 'HOSTED_OPENAI_PAID_ENABLED': 'False', | |||
| 'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1, | |||
| 'HOSTED_AZURE_OPENAI_ENABLED': 'False', | |||
| 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, | |||
| 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, | |||
| 'HOSTED_ANTHROPIC_ENABLED': 'False', | |||
| 'HOSTED_ANTHROPIC_PAID_ENABLED': 'False', | |||
| 'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000, | |||
| 'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20, | |||
| 'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100, | |||
| 'HOSTED_MODERATION_ENABLED': 'False', | |||
| 'HOSTED_MODERATION_PROVIDERS': '', | |||
| 'CLEAN_DAY_SETTING': 30, | |||
| @@ -268,8 +261,6 @@ class Config: | |||
| self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') | |||
| self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')) | |||
| self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') | |||
| self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID') | |||
| self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA')) | |||
| self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') | |||
| self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') | |||
| @@ -281,10 +272,6 @@ class Config: | |||
| self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY') | |||
| self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')) | |||
| self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') | |||
| self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID') | |||
| self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')) | |||
| self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY')) | |||
| self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY')) | |||
| self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED') | |||
| self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS') | |||
| @@ -302,6 +289,3 @@ class CloudEditionConfig(Config): | |||
| self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID') | |||
| self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET') | |||
| self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH') | |||
| self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') | |||
| self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') | |||
| @@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m | |||
| # Import universal chat controllers | |||
| from .universal_chat import chat, conversation, message, parameter, audio | |||
| # Import webhook controllers | |||
| from .webhook import stripe | |||
| from .billing import billing | |||
| @@ -1,9 +1,6 @@ | |||
| import stripe | |||
| import os | |||
| from flask_restful import Resource, reqparse | |||
| from flask_login import current_user | |||
| from flask import current_app, request | |||
| from flask import current_app | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| @@ -40,7 +37,10 @@ class Subscription(Resource): | |||
| parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) | |||
| args = parser.parse_args() | |||
| return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id) | |||
| return BillingService.get_subscription(args['plan'], | |||
| args['interval'], | |||
| current_user.email, | |||
| current_user.current_tenant_id) | |||
| class Invoices(Resource): | |||
| @@ -54,32 +54,6 @@ class Invoices(Resource): | |||
| return BillingService.get_invoices(current_user.email) | |||
| class StripeBillingWebhook(Resource): | |||
| @setup_required | |||
| @only_edition_cloud | |||
| def post(self): | |||
| payload = request.data | |||
| sig_header = request.headers.get('STRIPE_SIGNATURE') | |||
| webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET') | |||
| try: | |||
| event = stripe.Webhook.construct_event( | |||
| payload, sig_header, webhook_secret | |||
| ) | |||
| except ValueError as e: | |||
| # Invalid payload | |||
| return 'Invalid payload', 400 | |||
| except stripe.error.SignatureVerificationError as e: | |||
| # Invalid signature | |||
| return 'Invalid signature', 400 | |||
| BillingService.process_event(event) | |||
| return 'success', 200 | |||
| api.add_resource(BillingInfo, '/billing/info') | |||
| api.add_resource(Subscription, '/billing/subscription') | |||
| api.add_resource(Invoices, '/billing/invoices') | |||
| api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe') | |||
| @@ -1,61 +0,0 @@ | |||
| import logging | |||
| import stripe | |||
| from flask import request, current_app | |||
| from flask_restful import Resource | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import only_edition_cloud | |||
| from services.provider_checkout_service import ProviderCheckoutService | |||
| class StripeWebhookApi(Resource): | |||
| @setup_required | |||
| @only_edition_cloud | |||
| def post(self): | |||
| payload = request.data | |||
| sig_header = request.headers.get('STRIPE_SIGNATURE') | |||
| webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET') | |||
| try: | |||
| event = stripe.Webhook.construct_event( | |||
| payload, sig_header, webhook_secret | |||
| ) | |||
| except ValueError as e: | |||
| # Invalid payload | |||
| return 'Invalid payload', 400 | |||
| except stripe.error.SignatureVerificationError as e: | |||
| # Invalid signature | |||
| return 'Invalid signature', 400 | |||
| # Handle the checkout.session.completed event | |||
| if event['type'] == 'checkout.session.completed': | |||
| logging.debug(event['data']['object']['id']) | |||
| logging.debug(event['data']['object']['amount_subtotal']) | |||
| logging.debug(event['data']['object']['currency']) | |||
| logging.debug(event['data']['object']['payment_intent']) | |||
| logging.debug(event['data']['object']['payment_status']) | |||
| logging.debug(event['data']['object']['metadata']) | |||
| session = stripe.checkout.Session.retrieve( | |||
| event['data']['object']['id'], | |||
| expand=['line_items'], | |||
| ) | |||
| logging.debug(session.line_items['data'][0]['quantity']) | |||
| # Fulfill the purchase... | |||
| provider_checkout_service = ProviderCheckoutService() | |||
| try: | |||
| provider_checkout_service.fulfill_provider_order(event, session.line_items) | |||
| except Exception as e: | |||
| logging.debug(str(e)) | |||
| return 'success', 200 | |||
| return 'success', 200 | |||
| api.add_resource(StripeWebhookApi, '/webhook/stripe') | |||
| @@ -9,8 +9,8 @@ from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from services.provider_checkout_service import ProviderCheckoutService | |||
| from services.provider_service import ProviderService | |||
| from services.billing_service import BillingService | |||
| class ModelProviderListApi(Resource): | |||
| @@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| provider_service = ProviderCheckoutService() | |||
| provider_checkout = provider_service.create_checkout( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| account=current_user | |||
| ) | |||
| if provider_name != 'anthropic': | |||
| raise ValueError(f'provider name {provider_name} is invalid') | |||
| return { | |||
| 'url': provider_checkout.get_checkout_url() | |||
| } | |||
| data = BillingService.get_model_provider_payment_link(provider_name=provider_name, | |||
| tenant_id=current_user.current_tenant_id, | |||
| account_id=current_user.id) | |||
| return data | |||
| class ModelProviderFreeQuotaSubmitApi(Resource): | |||
| @@ -191,23 +191,6 @@ class AnthropicProvider(BaseModelProvider): | |||
| return False | |||
| def get_payment_info(self) -> Optional[dict]: | |||
| """ | |||
| get product info if it payable. | |||
| :return: | |||
| """ | |||
| if hosted_model_providers.anthropic \ | |||
| and hosted_model_providers.anthropic.paid_enabled: | |||
| return { | |||
| 'product_id': hosted_model_providers.anthropic.paid_stripe_price_id, | |||
| 'increase_quota': hosted_model_providers.anthropic.paid_increase_quota, | |||
| 'min_quantity': hosted_model_providers.anthropic.paid_min_quantity, | |||
| 'max_quantity': hosted_model_providers.anthropic.paid_max_quantity, | |||
| } | |||
| return None | |||
| @classmethod | |||
| def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | |||
| """ | |||
| @@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC): | |||
| ).update({'last_used': datetime.utcnow()}) | |||
| db.session.commit() | |||
| def get_payment_info(self) -> Optional[dict]: | |||
| """ | |||
| get product info if it payable. | |||
| :return: | |||
| """ | |||
| return None | |||
| def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel: | |||
| """ | |||
| get provider model. | |||
| @@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel): | |||
| quota_limit: int = 0 | |||
| """Quota limit for the openai hosted model. -1 means unlimited.""" | |||
| paid_enabled: bool = False | |||
| paid_stripe_price_id: str = None | |||
| paid_increase_quota: int = 1 | |||
| class HostedAzureOpenAI(BaseModel): | |||
| @@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel): | |||
| quota_limit: int = 0 | |||
| """Quota limit for the anthropic hosted model. -1 means unlimited.""" | |||
| paid_enabled: bool = False | |||
| paid_stripe_price_id: str = None | |||
| paid_increase_quota: int = 1000000 | |||
| paid_min_quantity: int = 20 | |||
| paid_max_quantity: int = 100 | |||
| class HostedModelProviders(BaseModel): | |||
| @@ -68,8 +62,6 @@ def init_app(app: Flask): | |||
| api_key=app.config.get("HOSTED_OPENAI_API_KEY"), | |||
| quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"), | |||
| paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"), | |||
| paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"), | |||
| paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"), | |||
| ) | |||
| if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"): | |||
| @@ -85,10 +77,6 @@ def init_app(app: Flask): | |||
| api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"), | |||
| quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"), | |||
| paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"), | |||
| paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"), | |||
| paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"), | |||
| paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"), | |||
| paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"), | |||
| ) | |||
| if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"): | |||
| @@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider): | |||
| return False | |||
| def get_payment_info(self) -> Optional[dict]: | |||
| """ | |||
| get payment info if it payable. | |||
| :return: | |||
| """ | |||
| if hosted_model_providers.openai \ | |||
| and hosted_model_providers.openai.paid_enabled: | |||
| return { | |||
| 'product_id': hosted_model_providers.openai.paid_stripe_price_id, | |||
| 'increase_quota': hosted_model_providers.openai.paid_increase_quota, | |||
| } | |||
| return None | |||
| @classmethod | |||
| def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | |||
| """ | |||
| @@ -1,6 +0,0 @@ | |||
| import stripe | |||
| def init_app(app): | |||
| if app.config.get('STRIPE_API_KEY'): | |||
| stripe.api_key = app.config.get('STRIPE_API_KEY') | |||
| @@ -135,21 +135,6 @@ class TenantPreferredModelProvider(db.Model): | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| class ProviderOrderPaymentStatus(Enum): | |||
| WAIT_PAY = 'wait_pay' | |||
| PAID = 'paid' | |||
| PAY_FAILED = 'pay_failed' | |||
| REFUNDED = 'refunded' | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in ProviderOrderPaymentStatus: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class ProviderOrder(db.Model): | |||
| __tablename__ = 'provider_orders' | |||
| __table_args__ = ( | |||
| @@ -46,7 +46,6 @@ websocket-client~=1.6.1 | |||
| dashscope~=1.11.0 | |||
| huggingface_hub~=0.16.4 | |||
| transformers~=4.31.0 | |||
| stripe~=5.5.0 | |||
| pandas==1.5.3 | |||
| xinference-client~=0.6.4 | |||
| safetensors==0.3.2 | |||
| @@ -10,7 +10,7 @@ class BillingService: | |||
| def get_info(cls, tenant_id: str): | |||
| params = {'tenant_id': tenant_id} | |||
| billing_info = cls._send_request('GET', '/info', params=params) | |||
| billing_info = cls._send_request('GET', '/subscription/info', params=params) | |||
| return billing_info | |||
| @@ -18,16 +18,26 @@ class BillingService: | |||
| def get_subscription(cls, plan: str, | |||
| interval: str, | |||
| prefilled_email: str = '', | |||
| user_name: str = '', | |||
| tenant_id: str = ''): | |||
| params = { | |||
| 'plan': plan, | |||
| 'interval': interval, | |||
| 'prefilled_email': prefilled_email, | |||
| 'user_name': user_name, | |||
| 'tenant_id': tenant_id | |||
| } | |||
| return cls._send_request('GET', '/subscription', params=params) | |||
| return cls._send_request('GET', '/subscription/payment-link', params=params) | |||
| @classmethod | |||
| def get_model_provider_payment_link(cls, | |||
| provider_name: str, | |||
| tenant_id: str, | |||
| account_id: str): | |||
| params = { | |||
| 'provider_name': provider_name, | |||
| 'tenant_id': tenant_id, | |||
| 'account_id': account_id | |||
| } | |||
| return cls._send_request('GET', '/model-provider/payment-link', params=params) | |||
| @classmethod | |||
| def get_invoices(cls, prefilled_email: str = ''): | |||
| @@ -45,10 +55,3 @@ class BillingService: | |||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||
| return response.json() | |||
| @classmethod | |||
| def process_event(cls, event: dict): | |||
| json = { | |||
| "content": event, | |||
| } | |||
| return cls._send_request('POST', '/webhook/stripe', json=json) | |||
| @@ -1,174 +0,0 @@ | |||
| import datetime | |||
| import logging | |||
| import stripe | |||
| from flask import current_app | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType | |||
| class ProviderCheckout: | |||
| def __init__(self, stripe_checkout_session): | |||
| self.stripe_checkout_session = stripe_checkout_session | |||
| def get_checkout_url(self): | |||
| return self.stripe_checkout_session.url | |||
| class ProviderCheckoutService: | |||
| def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout: | |||
| # check provider name is valid | |||
| model_provider_rules = ModelProviderFactory.get_provider_rules() | |||
| if provider_name not in model_provider_rules: | |||
| raise ValueError(f'provider name {provider_name} is invalid') | |||
| model_provider_rule = model_provider_rules[provider_name] | |||
| # check provider name can be paid | |||
| self._check_provider_payable(provider_name, model_provider_rule) | |||
| # get stripe checkout product id | |||
| paid_provider = self._get_paid_provider(tenant_id, provider_name) | |||
| model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name) | |||
| model_provider = model_provider_class(provider=paid_provider) | |||
| payment_info = model_provider.get_payment_info() | |||
| if not payment_info: | |||
| raise ValueError(f'provider name {provider_name} not support payment') | |||
| payment_product_id = payment_info['product_id'] | |||
| payment_min_quantity = payment_info['min_quantity'] | |||
| payment_max_quantity = payment_info['max_quantity'] | |||
| # create provider order | |||
| provider_order = ProviderOrder( | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_name, | |||
| account_id=account.id, | |||
| payment_product_id=payment_product_id, | |||
| quantity=1, | |||
| payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value | |||
| ) | |||
| db.session.add(provider_order) | |||
| db.session.flush() | |||
| line_item = { | |||
| 'price': f'{payment_product_id}', | |||
| 'quantity': payment_min_quantity | |||
| } | |||
| if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity: | |||
| line_item['adjustable_quantity'] = { | |||
| 'enabled': True, | |||
| 'minimum': payment_min_quantity, | |||
| 'maximum': payment_max_quantity | |||
| } | |||
| try: | |||
| # create stripe checkout session | |||
| checkout_session = stripe.checkout.Session.create( | |||
| line_items=[ | |||
| line_item | |||
| ], | |||
| mode='payment', | |||
| success_url=current_app.config.get("CONSOLE_WEB_URL") | |||
| + f'?provider_name={provider_name}&payment_result=succeeded', | |||
| cancel_url=current_app.config.get("CONSOLE_WEB_URL") | |||
| + f'?provider_name={provider_name}&payment_result=cancelled', | |||
| automatic_tax={'enabled': True}, | |||
| ) | |||
| except Exception as e: | |||
| logging.exception(e) | |||
| raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later') | |||
| provider_order.payment_id = checkout_session.id | |||
| db.session.commit() | |||
| return ProviderCheckout(checkout_session) | |||
| def fulfill_provider_order(self, event, line_items): | |||
| provider_order = db.session.query(ProviderOrder) \ | |||
| .filter(ProviderOrder.payment_id == event['data']['object']['id']) \ | |||
| .first() | |||
| if not provider_order: | |||
| raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}') | |||
| if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value: | |||
| raise ValueError( | |||
| f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}') | |||
| provider_order.transaction_id = event['data']['object']['payment_intent'] | |||
| provider_order.currency = event['data']['object']['currency'] | |||
| provider_order.total_amount = event['data']['object']['amount_subtotal'] | |||
| provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value | |||
| provider_order.paid_at = datetime.datetime.utcnow() | |||
| provider_order.updated_at = provider_order.paid_at | |||
| # update provider quota | |||
| provider = db.session.query(Provider).filter( | |||
| Provider.tenant_id == provider_order.tenant_id, | |||
| Provider.provider_name == provider_order.provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == ProviderQuotaType.PAID.value | |||
| ).first() | |||
| if not provider: | |||
| raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, ' | |||
| f'provider name: {provider_order.provider_name}') | |||
| model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name) | |||
| model_provider = model_provider_class(provider=provider) | |||
| payment_info = model_provider.get_payment_info() | |||
| quantity = line_items['data'][0]['quantity'] | |||
| if not payment_info: | |||
| increase_quota = 0 | |||
| else: | |||
| increase_quota = int(payment_info['increase_quota']) * quantity | |||
| if increase_quota > 0: | |||
| provider.quota_limit += increase_quota | |||
| provider.is_valid = True | |||
| db.session.commit() | |||
| def _check_provider_payable(self, provider_name: str, model_provider_rule: dict): | |||
| if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']: | |||
| raise ValueError(f'provider name {provider_name} not support payment') | |||
| if 'system_config' not in model_provider_rule: | |||
| raise ValueError(f'provider name {provider_name} not support payment') | |||
| if 'supported_quota_types' not in model_provider_rule['system_config']: | |||
| raise ValueError(f'provider name {provider_name} not support payment') | |||
| if 'paid' not in model_provider_rule['system_config']['supported_quota_types']: | |||
| raise ValueError(f'provider name {provider_name} not support payment') | |||
| def _get_paid_provider(self, tenant_id: str, provider_name: str): | |||
| paid_provider = db.session.query(Provider) \ | |||
| .filter( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == ProviderQuotaType.PAID.value, | |||
| ).first() | |||
| if not paid_provider: | |||
| paid_provider = Provider( | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_name, | |||
| provider_type=ProviderType.SYSTEM.value, | |||
| quota_type=ProviderQuotaType.PAID.value, | |||
| quota_limit=0, | |||
| quota_used=0, | |||
| ) | |||
| db.session.add(paid_provider) | |||
| db.session.commit() | |||
| return paid_provider | |||