| HOSTED_OPENAI_API_ORGANIZATION= | HOSTED_OPENAI_API_ORGANIZATION= | ||||
| HOSTED_OPENAI_QUOTA_LIMIT=200 | HOSTED_OPENAI_QUOTA_LIMIT=200 | ||||
| HOSTED_OPENAI_PAID_ENABLED=false | HOSTED_OPENAI_PAID_ENABLED=false | ||||
| HOSTED_OPENAI_PAID_STRIPE_PRICE_ID= | |||||
| HOSTED_OPENAI_PAID_INCREASE_QUOTA=1 | |||||
| HOSTED_AZURE_OPENAI_ENABLED=false | HOSTED_AZURE_OPENAI_ENABLED=false | ||||
| HOSTED_AZURE_OPENAI_API_KEY= | HOSTED_AZURE_OPENAI_API_KEY= | ||||
| HOSTED_ANTHROPIC_API_KEY= | HOSTED_ANTHROPIC_API_KEY= | ||||
| HOSTED_ANTHROPIC_QUOTA_LIMIT=600000 | HOSTED_ANTHROPIC_QUOTA_LIMIT=600000 | ||||
| HOSTED_ANTHROPIC_PAID_ENABLED=false | HOSTED_ANTHROPIC_PAID_ENABLED=false | ||||
| HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID= | |||||
| HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000 | |||||
| HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 | |||||
| HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 | |||||
| # Stripe configuration | |||||
| STRIPE_API_KEY= | |||||
| STRIPE_WEBHOOK_SECRET= | |||||
| # Billing configuration | # Billing configuration | ||||
| BILLING_API_URL=http://127.0.0.1:8000/v1 | BILLING_API_URL=http://127.0.0.1:8000/v1 | ||||
| BILLING_API_SECRET_KEY= | |||||
| STRIPE_WEBHOOK_BILLING_SECRET= | |||||
| BILLING_API_SECRET_KEY= |
| from core.model_providers.providers import hosted | from core.model_providers.providers import hosted | ||||
| from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | ||||
| ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension | |||||
| ext_database, ext_storage, ext_mail, ext_code_based_extension | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_login import login_manager | from extensions.ext_login import login_manager | ||||
| ext_login.init_app(app) | ext_login.init_app(app) | ||||
| ext_mail.init_app(app) | ext_mail.init_app(app) | ||||
| ext_sentry.init_app(app) | ext_sentry.init_app(app) | ||||
| ext_stripe.init_app(app) | |||||
| # Flask-Login configuration | # Flask-Login configuration |
| # -*- coding:utf-8 -*- | # -*- coding:utf-8 -*- | ||||
| import os | import os | ||||
| from datetime import timedelta | |||||
| import dotenv | import dotenv | ||||
| from extensions.ext_database import db | |||||
| from extensions.ext_redis import redis_client | |||||
| dotenv.load_dotenv() | dotenv.load_dotenv() | ||||
| 'HOSTED_OPENAI_QUOTA_LIMIT': 200, | 'HOSTED_OPENAI_QUOTA_LIMIT': 200, | ||||
| 'HOSTED_OPENAI_ENABLED': 'False', | 'HOSTED_OPENAI_ENABLED': 'False', | ||||
| 'HOSTED_OPENAI_PAID_ENABLED': 'False', | 'HOSTED_OPENAI_PAID_ENABLED': 'False', | ||||
| 'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1, | |||||
| 'HOSTED_AZURE_OPENAI_ENABLED': 'False', | 'HOSTED_AZURE_OPENAI_ENABLED': 'False', | ||||
| 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, | 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, | ||||
| 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, | 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, | ||||
| 'HOSTED_ANTHROPIC_ENABLED': 'False', | 'HOSTED_ANTHROPIC_ENABLED': 'False', | ||||
| 'HOSTED_ANTHROPIC_PAID_ENABLED': 'False', | 'HOSTED_ANTHROPIC_PAID_ENABLED': 'False', | ||||
| 'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000, | |||||
| 'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20, | |||||
| 'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100, | |||||
| 'HOSTED_MODERATION_ENABLED': 'False', | 'HOSTED_MODERATION_ENABLED': 'False', | ||||
| 'HOSTED_MODERATION_PROVIDERS': '', | 'HOSTED_MODERATION_PROVIDERS': '', | ||||
| 'CLEAN_DAY_SETTING': 30, | 'CLEAN_DAY_SETTING': 30, | ||||
| self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') | self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') | ||||
| self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')) | self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')) | ||||
| self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') | self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') | ||||
| self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID') | |||||
| self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA')) | |||||
| self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') | self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') | ||||
| self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') | self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') | ||||
| self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY') | self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY') | ||||
| self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')) | self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')) | ||||
| self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') | self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') | ||||
| self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID') | |||||
| self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')) | |||||
| self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY')) | |||||
| self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY')) | |||||
| self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED') | self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED') | ||||
| self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS') | self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS') | ||||
| self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID') | self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID') | ||||
| self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET') | self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET') | ||||
| self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH') | self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH') | ||||
| self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') | |||||
| self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') |
| # Import universal chat controllers | # Import universal chat controllers | ||||
| from .universal_chat import chat, conversation, message, parameter, audio | from .universal_chat import chat, conversation, message, parameter, audio | ||||
| # Import webhook controllers | |||||
| from .webhook import stripe | |||||
| from .billing import billing | from .billing import billing |
| import stripe | |||||
| import os | |||||
| from flask_restful import Resource, reqparse | from flask_restful import Resource, reqparse | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from flask import current_app, request | |||||
| from flask import current_app | |||||
| from controllers.console import api | from controllers.console import api | ||||
| from controllers.console.setup import setup_required | from controllers.console.setup import setup_required | ||||
| parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) | parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id) | |||||
| return BillingService.get_subscription(args['plan'], | |||||
| args['interval'], | |||||
| current_user.email, | |||||
| current_user.current_tenant_id) | |||||
| class Invoices(Resource): | class Invoices(Resource): | ||||
| return BillingService.get_invoices(current_user.email) | return BillingService.get_invoices(current_user.email) | ||||
| class StripeBillingWebhook(Resource): | |||||
| @setup_required | |||||
| @only_edition_cloud | |||||
| def post(self): | |||||
| payload = request.data | |||||
| sig_header = request.headers.get('STRIPE_SIGNATURE') | |||||
| webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET') | |||||
| try: | |||||
| event = stripe.Webhook.construct_event( | |||||
| payload, sig_header, webhook_secret | |||||
| ) | |||||
| except ValueError as e: | |||||
| # Invalid payload | |||||
| return 'Invalid payload', 400 | |||||
| except stripe.error.SignatureVerificationError as e: | |||||
| # Invalid signature | |||||
| return 'Invalid signature', 400 | |||||
| BillingService.process_event(event) | |||||
| return 'success', 200 | |||||
| api.add_resource(BillingInfo, '/billing/info') | api.add_resource(BillingInfo, '/billing/info') | ||||
| api.add_resource(Subscription, '/billing/subscription') | api.add_resource(Subscription, '/billing/subscription') | ||||
| api.add_resource(Invoices, '/billing/invoices') | api.add_resource(Invoices, '/billing/invoices') | ||||
| api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe') |
| import logging | |||||
| import stripe | |||||
| from flask import request, current_app | |||||
| from flask_restful import Resource | |||||
| from controllers.console import api | |||||
| from controllers.console.setup import setup_required | |||||
| from controllers.console.wraps import only_edition_cloud | |||||
| from services.provider_checkout_service import ProviderCheckoutService | |||||
| class StripeWebhookApi(Resource): | |||||
| @setup_required | |||||
| @only_edition_cloud | |||||
| def post(self): | |||||
| payload = request.data | |||||
| sig_header = request.headers.get('STRIPE_SIGNATURE') | |||||
| webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET') | |||||
| try: | |||||
| event = stripe.Webhook.construct_event( | |||||
| payload, sig_header, webhook_secret | |||||
| ) | |||||
| except ValueError as e: | |||||
| # Invalid payload | |||||
| return 'Invalid payload', 400 | |||||
| except stripe.error.SignatureVerificationError as e: | |||||
| # Invalid signature | |||||
| return 'Invalid signature', 400 | |||||
| # Handle the checkout.session.completed event | |||||
| if event['type'] == 'checkout.session.completed': | |||||
| logging.debug(event['data']['object']['id']) | |||||
| logging.debug(event['data']['object']['amount_subtotal']) | |||||
| logging.debug(event['data']['object']['currency']) | |||||
| logging.debug(event['data']['object']['payment_intent']) | |||||
| logging.debug(event['data']['object']['payment_status']) | |||||
| logging.debug(event['data']['object']['metadata']) | |||||
| session = stripe.checkout.Session.retrieve( | |||||
| event['data']['object']['id'], | |||||
| expand=['line_items'], | |||||
| ) | |||||
| logging.debug(session.line_items['data'][0]['quantity']) | |||||
| # Fulfill the purchase... | |||||
| provider_checkout_service = ProviderCheckoutService() | |||||
| try: | |||||
| provider_checkout_service.fulfill_provider_order(event, session.line_items) | |||||
| except Exception as e: | |||||
| logging.debug(str(e)) | |||||
| return 'success', 200 | |||||
| return 'success', 200 | |||||
| api.add_resource(StripeWebhookApi, '/webhook/stripe') |
| from controllers.console.wraps import account_initialization_required | from controllers.console.wraps import account_initialization_required | ||||
| from core.model_providers.error import LLMBadRequestError | from core.model_providers.error import LLMBadRequestError | ||||
| from core.model_providers.providers.base import CredentialsValidateFailedError | from core.model_providers.providers.base import CredentialsValidateFailedError | ||||
| from services.provider_checkout_service import ProviderCheckoutService | |||||
| from services.provider_service import ProviderService | from services.provider_service import ProviderService | ||||
| from services.billing_service import BillingService | |||||
| class ModelProviderListApi(Resource): | class ModelProviderListApi(Resource): | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self, provider_name: str): | def get(self, provider_name: str): | ||||
| provider_service = ProviderCheckoutService() | |||||
| provider_checkout = provider_service.create_checkout( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider_name=provider_name, | |||||
| account=current_user | |||||
| ) | |||||
| if provider_name != 'anthropic': | |||||
| raise ValueError(f'provider name {provider_name} is invalid') | |||||
| return { | |||||
| 'url': provider_checkout.get_checkout_url() | |||||
| } | |||||
| data = BillingService.get_model_provider_payment_link(provider_name=provider_name, | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| account_id=current_user.id) | |||||
| return data | |||||
| class ModelProviderFreeQuotaSubmitApi(Resource): | class ModelProviderFreeQuotaSubmitApi(Resource): |
| return False | return False | ||||
| def get_payment_info(self) -> Optional[dict]: | |||||
| """ | |||||
| get product info if it payable. | |||||
| :return: | |||||
| """ | |||||
| if hosted_model_providers.anthropic \ | |||||
| and hosted_model_providers.anthropic.paid_enabled: | |||||
| return { | |||||
| 'product_id': hosted_model_providers.anthropic.paid_stripe_price_id, | |||||
| 'increase_quota': hosted_model_providers.anthropic.paid_increase_quota, | |||||
| 'min_quantity': hosted_model_providers.anthropic.paid_min_quantity, | |||||
| 'max_quantity': hosted_model_providers.anthropic.paid_max_quantity, | |||||
| } | |||||
| return None | |||||
| @classmethod | @classmethod | ||||
| def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | ||||
| """ | """ |
| ).update({'last_used': datetime.utcnow()}) | ).update({'last_used': datetime.utcnow()}) | ||||
| db.session.commit() | db.session.commit() | ||||
| def get_payment_info(self) -> Optional[dict]: | |||||
| """ | |||||
| get product info if it payable. | |||||
| :return: | |||||
| """ | |||||
| return None | |||||
| def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel: | def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel: | ||||
| """ | """ | ||||
| get provider model. | get provider model. |
| quota_limit: int = 0 | quota_limit: int = 0 | ||||
| """Quota limit for the openai hosted model. -1 means unlimited.""" | """Quota limit for the openai hosted model. -1 means unlimited.""" | ||||
| paid_enabled: bool = False | paid_enabled: bool = False | ||||
| paid_stripe_price_id: str = None | |||||
| paid_increase_quota: int = 1 | |||||
| class HostedAzureOpenAI(BaseModel): | class HostedAzureOpenAI(BaseModel): | ||||
| quota_limit: int = 0 | quota_limit: int = 0 | ||||
| """Quota limit for the anthropic hosted model. -1 means unlimited.""" | """Quota limit for the anthropic hosted model. -1 means unlimited.""" | ||||
| paid_enabled: bool = False | paid_enabled: bool = False | ||||
| paid_stripe_price_id: str = None | |||||
| paid_increase_quota: int = 1000000 | |||||
| paid_min_quantity: int = 20 | |||||
| paid_max_quantity: int = 100 | |||||
| class HostedModelProviders(BaseModel): | class HostedModelProviders(BaseModel): | ||||
| api_key=app.config.get("HOSTED_OPENAI_API_KEY"), | api_key=app.config.get("HOSTED_OPENAI_API_KEY"), | ||||
| quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"), | quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"), | ||||
| paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"), | paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"), | ||||
| paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"), | |||||
| paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"), | |||||
| ) | ) | ||||
| if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"): | if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"): | ||||
| api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"), | api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"), | ||||
| quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"), | quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"), | ||||
| paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"), | paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"), | ||||
| paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"), | |||||
| paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"), | |||||
| paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"), | |||||
| paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"), | |||||
| ) | ) | ||||
| if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"): | if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"): |
| return False | return False | ||||
| def get_payment_info(self) -> Optional[dict]: | |||||
| """ | |||||
| get payment info if it payable. | |||||
| :return: | |||||
| """ | |||||
| if hosted_model_providers.openai \ | |||||
| and hosted_model_providers.openai.paid_enabled: | |||||
| return { | |||||
| 'product_id': hosted_model_providers.openai.paid_stripe_price_id, | |||||
| 'increase_quota': hosted_model_providers.openai.paid_increase_quota, | |||||
| } | |||||
| return None | |||||
| @classmethod | @classmethod | ||||
| def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | ||||
| """ | """ |
| import stripe | |||||
| def init_app(app): | |||||
| if app.config.get('STRIPE_API_KEY'): | |||||
| stripe.api_key = app.config.get('STRIPE_API_KEY') |
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | ||||
| class ProviderOrderPaymentStatus(Enum): | |||||
| WAIT_PAY = 'wait_pay' | |||||
| PAID = 'paid' | |||||
| PAY_FAILED = 'pay_failed' | |||||
| REFUNDED = 'refunded' | |||||
| @staticmethod | |||||
| def value_of(value): | |||||
| for member in ProviderOrderPaymentStatus: | |||||
| if member.value == value: | |||||
| return member | |||||
| raise ValueError(f"No matching enum found for value '{value}'") | |||||
| class ProviderOrder(db.Model): | class ProviderOrder(db.Model): | ||||
| __tablename__ = 'provider_orders' | __tablename__ = 'provider_orders' | ||||
| __table_args__ = ( | __table_args__ = ( |
| dashscope~=1.11.0 | dashscope~=1.11.0 | ||||
| huggingface_hub~=0.16.4 | huggingface_hub~=0.16.4 | ||||
| transformers~=4.31.0 | transformers~=4.31.0 | ||||
| stripe~=5.5.0 | |||||
| pandas==1.5.3 | pandas==1.5.3 | ||||
| xinference-client~=0.6.4 | xinference-client~=0.6.4 | ||||
| safetensors==0.3.2 | safetensors==0.3.2 |
| def get_info(cls, tenant_id: str): | def get_info(cls, tenant_id: str): | ||||
| params = {'tenant_id': tenant_id} | params = {'tenant_id': tenant_id} | ||||
| billing_info = cls._send_request('GET', '/info', params=params) | |||||
| billing_info = cls._send_request('GET', '/subscription/info', params=params) | |||||
| return billing_info | return billing_info | ||||
| def get_subscription(cls, plan: str, | def get_subscription(cls, plan: str, | ||||
| interval: str, | interval: str, | ||||
| prefilled_email: str = '', | prefilled_email: str = '', | ||||
| user_name: str = '', | |||||
| tenant_id: str = ''): | tenant_id: str = ''): | ||||
| params = { | params = { | ||||
| 'plan': plan, | 'plan': plan, | ||||
| 'interval': interval, | 'interval': interval, | ||||
| 'prefilled_email': prefilled_email, | 'prefilled_email': prefilled_email, | ||||
| 'user_name': user_name, | |||||
| 'tenant_id': tenant_id | 'tenant_id': tenant_id | ||||
| } | } | ||||
| return cls._send_request('GET', '/subscription', params=params) | |||||
| return cls._send_request('GET', '/subscription/payment-link', params=params) | |||||
| @classmethod | |||||
| def get_model_provider_payment_link(cls, | |||||
| provider_name: str, | |||||
| tenant_id: str, | |||||
| account_id: str): | |||||
| params = { | |||||
| 'provider_name': provider_name, | |||||
| 'tenant_id': tenant_id, | |||||
| 'account_id': account_id | |||||
| } | |||||
| return cls._send_request('GET', '/model-provider/payment-link', params=params) | |||||
| @classmethod | @classmethod | ||||
| def get_invoices(cls, prefilled_email: str = ''): | def get_invoices(cls, prefilled_email: str = ''): | ||||
| response = requests.request(method, url, json=json, params=params, headers=headers) | response = requests.request(method, url, json=json, params=params, headers=headers) | ||||
| return response.json() | return response.json() | ||||
| @classmethod | |||||
| def process_event(cls, event: dict): | |||||
| json = { | |||||
| "content": event, | |||||
| } | |||||
| return cls._send_request('POST', '/webhook/stripe', json=json) |
| import datetime | |||||
| import logging | |||||
| import stripe | |||||
| from flask import current_app | |||||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||||
| from extensions.ext_database import db | |||||
| from models.account import Account | |||||
| from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType | |||||
| class ProviderCheckout: | |||||
| def __init__(self, stripe_checkout_session): | |||||
| self.stripe_checkout_session = stripe_checkout_session | |||||
| def get_checkout_url(self): | |||||
| return self.stripe_checkout_session.url | |||||
| class ProviderCheckoutService: | |||||
| def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout: | |||||
| # check provider name is valid | |||||
| model_provider_rules = ModelProviderFactory.get_provider_rules() | |||||
| if provider_name not in model_provider_rules: | |||||
| raise ValueError(f'provider name {provider_name} is invalid') | |||||
| model_provider_rule = model_provider_rules[provider_name] | |||||
| # check provider name can be paid | |||||
| self._check_provider_payable(provider_name, model_provider_rule) | |||||
| # get stripe checkout product id | |||||
| paid_provider = self._get_paid_provider(tenant_id, provider_name) | |||||
| model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name) | |||||
| model_provider = model_provider_class(provider=paid_provider) | |||||
| payment_info = model_provider.get_payment_info() | |||||
| if not payment_info: | |||||
| raise ValueError(f'provider name {provider_name} not support payment') | |||||
| payment_product_id = payment_info['product_id'] | |||||
| payment_min_quantity = payment_info['min_quantity'] | |||||
| payment_max_quantity = payment_info['max_quantity'] | |||||
| # create provider order | |||||
| provider_order = ProviderOrder( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_name, | |||||
| account_id=account.id, | |||||
| payment_product_id=payment_product_id, | |||||
| quantity=1, | |||||
| payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value | |||||
| ) | |||||
| db.session.add(provider_order) | |||||
| db.session.flush() | |||||
| line_item = { | |||||
| 'price': f'{payment_product_id}', | |||||
| 'quantity': payment_min_quantity | |||||
| } | |||||
| if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity: | |||||
| line_item['adjustable_quantity'] = { | |||||
| 'enabled': True, | |||||
| 'minimum': payment_min_quantity, | |||||
| 'maximum': payment_max_quantity | |||||
| } | |||||
| try: | |||||
| # create stripe checkout session | |||||
| checkout_session = stripe.checkout.Session.create( | |||||
| line_items=[ | |||||
| line_item | |||||
| ], | |||||
| mode='payment', | |||||
| success_url=current_app.config.get("CONSOLE_WEB_URL") | |||||
| + f'?provider_name={provider_name}&payment_result=succeeded', | |||||
| cancel_url=current_app.config.get("CONSOLE_WEB_URL") | |||||
| + f'?provider_name={provider_name}&payment_result=cancelled', | |||||
| automatic_tax={'enabled': True}, | |||||
| ) | |||||
| except Exception as e: | |||||
| logging.exception(e) | |||||
| raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later') | |||||
| provider_order.payment_id = checkout_session.id | |||||
| db.session.commit() | |||||
| return ProviderCheckout(checkout_session) | |||||
| def fulfill_provider_order(self, event, line_items): | |||||
| provider_order = db.session.query(ProviderOrder) \ | |||||
| .filter(ProviderOrder.payment_id == event['data']['object']['id']) \ | |||||
| .first() | |||||
| if not provider_order: | |||||
| raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}') | |||||
| if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value: | |||||
| raise ValueError( | |||||
| f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}') | |||||
| provider_order.transaction_id = event['data']['object']['payment_intent'] | |||||
| provider_order.currency = event['data']['object']['currency'] | |||||
| provider_order.total_amount = event['data']['object']['amount_subtotal'] | |||||
| provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value | |||||
| provider_order.paid_at = datetime.datetime.utcnow() | |||||
| provider_order.updated_at = provider_order.paid_at | |||||
| # update provider quota | |||||
| provider = db.session.query(Provider).filter( | |||||
| Provider.tenant_id == provider_order.tenant_id, | |||||
| Provider.provider_name == provider_order.provider_name, | |||||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||||
| Provider.quota_type == ProviderQuotaType.PAID.value | |||||
| ).first() | |||||
| if not provider: | |||||
| raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, ' | |||||
| f'provider name: {provider_order.provider_name}') | |||||
| model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name) | |||||
| model_provider = model_provider_class(provider=provider) | |||||
| payment_info = model_provider.get_payment_info() | |||||
| quantity = line_items['data'][0]['quantity'] | |||||
| if not payment_info: | |||||
| increase_quota = 0 | |||||
| else: | |||||
| increase_quota = int(payment_info['increase_quota']) * quantity | |||||
| if increase_quota > 0: | |||||
| provider.quota_limit += increase_quota | |||||
| provider.is_valid = True | |||||
| db.session.commit() | |||||
| def _check_provider_payable(self, provider_name: str, model_provider_rule: dict): | |||||
| if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']: | |||||
| raise ValueError(f'provider name {provider_name} not support payment') | |||||
| if 'system_config' not in model_provider_rule: | |||||
| raise ValueError(f'provider name {provider_name} not support payment') | |||||
| if 'supported_quota_types' not in model_provider_rule['system_config']: | |||||
| raise ValueError(f'provider name {provider_name} not support payment') | |||||
| if 'paid' not in model_provider_rule['system_config']['supported_quota_types']: | |||||
| raise ValueError(f'provider name {provider_name} not support payment') | |||||
| def _get_paid_provider(self, tenant_id: str, provider_name: str): | |||||
| paid_provider = db.session.query(Provider) \ | |||||
| .filter( | |||||
| Provider.tenant_id == tenant_id, | |||||
| Provider.provider_name == provider_name, | |||||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||||
| Provider.quota_type == ProviderQuotaType.PAID.value, | |||||
| ).first() | |||||
| if not paid_provider: | |||||
| paid_provider = Provider( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_name, | |||||
| provider_type=ProviderType.SYSTEM.value, | |||||
| quota_type=ProviderQuotaType.PAID.value, | |||||
| quota_limit=0, | |||||
| quota_used=0, | |||||
| ) | |||||
| db.session.add(paid_provider) | |||||
| db.session.commit() | |||||
| return paid_provider |