ソースを参照

remove stripe and anthropic. (#1746)

tags/0.3.34
Garfield Dai 1年前
コミット
6b499b9a16
コミッターのメールアドレスに関連付けられたアカウントが存在しません

+ 1
- 12
api/.env.example ファイルの表示

@@ -106,8 +106,6 @@ HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1

HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
@@ -119,16 +117,7 @@ HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100

# Stripe configuration
STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=

# Billing configuration
BILLING_API_URL=http://127.0.0.1:8000/v1
BILLING_API_SECRET_KEY=
STRIPE_WEBHOOK_BILLING_SECRET=
BILLING_API_SECRET_KEY=

+ 1
- 2
api/app.py ファイルの表示

@@ -20,7 +20,7 @@ from flask_cors import CORS

from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension
from extensions.ext_database import db
from extensions.ext_login import login_manager

@@ -96,7 +96,6 @@ def initialize_extensions(app):
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)


# Flask-Login configuration

+ 0
- 16
api/config.py ファイルの表示

@@ -1,11 +1,8 @@
# -*- coding:utf-8 -*-
import os
from datetime import timedelta

import dotenv

from extensions.ext_database import db
from extensions.ext_redis import redis_client

dotenv.load_dotenv()

@@ -44,15 +41,11 @@ DEFAULTS = {
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'CLEAN_DAY_SETTING': 30,
@@ -268,8 +261,6 @@ class Config:
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))

self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -281,10 +272,6 @@ class Config:
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))

self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
@@ -302,6 +289,3 @@ class CloudEditionConfig(Config):
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')

self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

+ 0
- 3
api/controllers/console/__init__.py ファイルの表示

@@ -26,7 +26,4 @@ from .explore import installed_app, recommended_app, completion, conversation, m
# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio

# Import webhook controllers
from .webhook import stripe

from .billing import billing

+ 5
- 31
api/controllers/console/billing/billing.py ファイルの表示

@@ -1,9 +1,6 @@
import stripe
import os

from flask_restful import Resource, reqparse
from flask_login import current_user
from flask import current_app, request
from flask import current_app

from controllers.console import api
from controllers.console.setup import setup_required
@@ -40,7 +37,10 @@ class Subscription(Resource):
parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args()

return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
return BillingService.get_subscription(args['plan'],
args['interval'],
current_user.email,
current_user.current_tenant_id)


class Invoices(Resource):
@@ -54,32 +54,6 @@ class Invoices(Resource):
return BillingService.get_invoices(current_user.email)


class StripeBillingWebhook(Resource):

@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')

try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400

BillingService.process_event(event)

return 'success', 200


api.add_resource(BillingInfo, '/billing/info')
api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, '/billing/invoices')
api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')

+ 0
- 0
api/controllers/console/webhook/__init__.py ファイルの表示


+ 0
- 61
api/controllers/console/webhook/stripe.py ファイルの表示

@@ -1,61 +0,0 @@
import logging

import stripe
from flask import request, current_app
from flask_restful import Resource

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import only_edition_cloud
from services.provider_checkout_service import ProviderCheckoutService


class StripeWebhookApi(Resource):
@setup_required
@only_edition_cloud
def post(self):
payload = request.data
sig_header = request.headers.get('STRIPE_SIGNATURE')
webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')

try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError as e:
# Invalid payload
return 'Invalid payload', 400
except stripe.error.SignatureVerificationError as e:
# Invalid signature
return 'Invalid signature', 400

# Handle the checkout.session.completed event
if event['type'] == 'checkout.session.completed':
logging.debug(event['data']['object']['id'])
logging.debug(event['data']['object']['amount_subtotal'])
logging.debug(event['data']['object']['currency'])
logging.debug(event['data']['object']['payment_intent'])
logging.debug(event['data']['object']['payment_status'])
logging.debug(event['data']['object']['metadata'])

session = stripe.checkout.Session.retrieve(
event['data']['object']['id'],
expand=['line_items'],
)

logging.debug(session.line_items['data'][0]['quantity'])

# Fulfill the purchase...
provider_checkout_service = ProviderCheckoutService()

try:
provider_checkout_service.fulfill_provider_order(event, session.line_items)
except Exception as e:

logging.debug(str(e))
return 'success', 200

return 'success', 200


api.add_resource(StripeWebhookApi, '/webhook/stripe')

+ 7
- 10
api/controllers/console/workspace/model_providers.py ファイルの表示

@@ -9,8 +9,8 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import CredentialsValidateFailedError
from services.provider_checkout_service import ProviderCheckoutService
from services.provider_service import ProviderService
from services.billing_service import BillingService


class ModelProviderListApi(Resource):
@@ -264,16 +264,13 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
provider_service = ProviderCheckoutService()
provider_checkout = provider_service.create_checkout(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
account=current_user
)
if provider_name != 'anthropic':
raise ValueError(f'provider name {provider_name} is invalid')

return {
'url': provider_checkout.get_checkout_url()
}
data = BillingService.get_model_provider_payment_link(provider_name=provider_name,
tenant_id=current_user.current_tenant_id,
account_id=current_user.id)
return data


class ModelProviderFreeQuotaSubmitApi(Resource):

+ 0
- 17
api/core/model_providers/providers/anthropic_provider.py ファイルの表示

@@ -191,23 +191,6 @@ class AnthropicProvider(BaseModelProvider):

return False

def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.

:return:
"""
if hosted_model_providers.anthropic \
and hosted_model_providers.anthropic.paid_enabled:
return {
'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
}

return None

@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""

+ 0
- 8
api/core/model_providers/providers/base.py ファイルの表示

@@ -267,14 +267,6 @@ class BaseModelProvider(BaseModel, ABC):
).update({'last_used': datetime.utcnow()})
db.session.commit()

def get_payment_info(self) -> Optional[dict]:
"""
get product info if it payable.

:return:
"""
return None

def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
"""
get provider model.

+ 0
- 12
api/core/model_providers/providers/hosted.py ファイルの表示

@@ -13,8 +13,6 @@ class HostedOpenAI(BaseModel):
quota_limit: int = 0
"""Quota limit for the openai hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1


class HostedAzureOpenAI(BaseModel):
@@ -30,10 +28,6 @@ class HostedAnthropic(BaseModel):
quota_limit: int = 0
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1000000
paid_min_quantity: int = 20
paid_max_quantity: int = 100


class HostedModelProviders(BaseModel):
@@ -68,8 +62,6 @@ def init_app(app: Flask):
api_key=app.config.get("HOSTED_OPENAI_API_KEY"),
quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"),
)

if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"):
@@ -85,10 +77,6 @@ def init_app(app: Flask):
api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"),
quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"),
paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)

if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):

+ 0
- 15
api/core/model_providers/providers/openai_provider.py ファイルの表示

@@ -282,21 +282,6 @@ class OpenAIProvider(BaseModelProvider):

return False

def get_payment_info(self) -> Optional[dict]:
"""
get payment info if it payable.

:return:
"""
if hosted_model_providers.openai \
and hosted_model_providers.openai.paid_enabled:
return {
'product_id': hosted_model_providers.openai.paid_stripe_price_id,
'increase_quota': hosted_model_providers.openai.paid_increase_quota,
}

return None

@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""

+ 0
- 6
api/extensions/ext_stripe.py ファイルの表示

@@ -1,6 +0,0 @@
import stripe


def init_app(app):
if app.config.get('STRIPE_API_KEY'):
stripe.api_key = app.config.get('STRIPE_API_KEY')

+ 0
- 15
api/models/provider.py ファイルの表示

@@ -135,21 +135,6 @@ class TenantPreferredModelProvider(db.Model):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


class ProviderOrderPaymentStatus(Enum):
WAIT_PAY = 'wait_pay'
PAID = 'paid'
PAY_FAILED = 'pay_failed'
REFUNDED = 'refunded'

@staticmethod
def value_of(value):
for member in ProviderOrderPaymentStatus:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")



class ProviderOrder(db.Model):
__tablename__ = 'provider_orders'
__table_args__ = (

+ 0
- 1
api/requirements.txt ファイルの表示

@@ -46,7 +46,6 @@ websocket-client~=1.6.1
dashscope~=1.11.0
huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference-client~=0.6.4
safetensors==0.3.2

+ 14
- 11
api/services/billing_service.py ファイルの表示

@@ -10,7 +10,7 @@ class BillingService:
def get_info(cls, tenant_id: str):
params = {'tenant_id': tenant_id}

billing_info = cls._send_request('GET', '/info', params=params)
billing_info = cls._send_request('GET', '/subscription/info', params=params)

return billing_info

@@ -18,16 +18,26 @@ class BillingService:
def get_subscription(cls, plan: str,
interval: str,
prefilled_email: str = '',
user_name: str = '',
tenant_id: str = ''):
params = {
'plan': plan,
'interval': interval,
'prefilled_email': prefilled_email,
'user_name': user_name,
'tenant_id': tenant_id
}
return cls._send_request('GET', '/subscription', params=params)
return cls._send_request('GET', '/subscription/payment-link', params=params)

@classmethod
def get_model_provider_payment_link(cls,
provider_name: str,
tenant_id: str,
account_id: str):
params = {
'provider_name': provider_name,
'tenant_id': tenant_id,
'account_id': account_id
}
return cls._send_request('GET', '/model-provider/payment-link', params=params)

@classmethod
def get_invoices(cls, prefilled_email: str = ''):
@@ -45,10 +55,3 @@ class BillingService:
response = requests.request(method, url, json=json, params=params, headers=headers)

return response.json()

@classmethod
def process_event(cls, event: dict):
json = {
"content": event,
}
return cls._send_request('POST', '/webhook/stripe', json=json)

+ 0
- 174
api/services/provider_checkout_service.py ファイルの表示

@@ -1,174 +0,0 @@
import datetime
import logging

import stripe
from flask import current_app

from core.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from models.account import Account
from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType


class ProviderCheckout:
def __init__(self, stripe_checkout_session):
self.stripe_checkout_session = stripe_checkout_session

def get_checkout_url(self):
return self.stripe_checkout_session.url


class ProviderCheckoutService:
def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout:
# check provider name is valid
model_provider_rules = ModelProviderFactory.get_provider_rules()
if provider_name not in model_provider_rules:
raise ValueError(f'provider name {provider_name} is invalid')

model_provider_rule = model_provider_rules[provider_name]

# check provider name can be paid
self._check_provider_payable(provider_name, model_provider_rule)

# get stripe checkout product id
paid_provider = self._get_paid_provider(tenant_id, provider_name)
model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
model_provider = model_provider_class(provider=paid_provider)
payment_info = model_provider.get_payment_info()
if not payment_info:
raise ValueError(f'provider name {provider_name} not support payment')

payment_product_id = payment_info['product_id']
payment_min_quantity = payment_info['min_quantity']
payment_max_quantity = payment_info['max_quantity']

# create provider order
provider_order = ProviderOrder(
tenant_id=tenant_id,
provider_name=provider_name,
account_id=account.id,
payment_product_id=payment_product_id,
quantity=1,
payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value
)

db.session.add(provider_order)
db.session.flush()

line_item = {
'price': f'{payment_product_id}',
'quantity': payment_min_quantity
}

if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity:
line_item['adjustable_quantity'] = {
'enabled': True,
'minimum': payment_min_quantity,
'maximum': payment_max_quantity
}

try:
# create stripe checkout session
checkout_session = stripe.checkout.Session.create(
line_items=[
line_item
],
mode='payment',
success_url=current_app.config.get("CONSOLE_WEB_URL")
+ f'?provider_name={provider_name}&payment_result=succeeded',
cancel_url=current_app.config.get("CONSOLE_WEB_URL")
+ f'?provider_name={provider_name}&payment_result=cancelled',
automatic_tax={'enabled': True},
)
except Exception as e:
logging.exception(e)
raise ValueError(f'provider name {provider_name} create checkout session failed, please try again later')

provider_order.payment_id = checkout_session.id
db.session.commit()

return ProviderCheckout(checkout_session)

def fulfill_provider_order(self, event, line_items):
provider_order = db.session.query(ProviderOrder) \
.filter(ProviderOrder.payment_id == event['data']['object']['id']) \
.first()

if not provider_order:
raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')

if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
raise ValueError(
f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')

provider_order.transaction_id = event['data']['object']['payment_intent']
provider_order.currency = event['data']['object']['currency']
provider_order.total_amount = event['data']['object']['amount_subtotal']
provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value
provider_order.paid_at = datetime.datetime.utcnow()
provider_order.updated_at = provider_order.paid_at

# update provider quota
provider = db.session.query(Provider).filter(
Provider.tenant_id == provider_order.tenant_id,
Provider.provider_name == provider_order.provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.PAID.value
).first()

if not provider:
raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, '
f'provider name: {provider_order.provider_name}')

model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name)
model_provider = model_provider_class(provider=provider)
payment_info = model_provider.get_payment_info()

quantity = line_items['data'][0]['quantity']

if not payment_info:
increase_quota = 0
else:
increase_quota = int(payment_info['increase_quota']) * quantity

if increase_quota > 0:
provider.quota_limit += increase_quota
provider.is_valid = True

db.session.commit()

def _check_provider_payable(self, provider_name: str, model_provider_rule: dict):
if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']:
raise ValueError(f'provider name {provider_name} not support payment')

if 'system_config' not in model_provider_rule:
raise ValueError(f'provider name {provider_name} not support payment')

if 'supported_quota_types' not in model_provider_rule['system_config']:
raise ValueError(f'provider name {provider_name} not support payment')

if 'paid' not in model_provider_rule['system_config']['supported_quota_types']:
raise ValueError(f'provider name {provider_name} not support payment')

def _get_paid_provider(self, tenant_id: str, provider_name: str):
paid_provider = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.PAID.value,
).first()

if not paid_provider:
paid_provider = Provider(
tenant_id=tenant_id,
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.PAID.value,
quota_limit=0,
quota_used=0,
)
db.session.add(paid_provider)
db.session.commit()

return paid_provider

読み込み中…
キャンセル
保存