瀏覽代碼

Feat/enterprise sso (#3602)

tags/0.6.4
Garfield Dai 1 年之前
父節點
當前提交
4481906be2
沒有連結到貢獻者的電子郵件帳戶。

+ 4
- 1
api/app.py 查看文件

@login_manager.request_loader @login_manager.request_loader
def load_user_from_request(request_from_flask_login): def load_user_from_request(request_from_flask_login):
"""Load user based on the request.""" """Load user based on the request."""
if request.blueprint == 'console':
if request.blueprint in ['console', 'inner_api']:
# Check if the user_id contains a dot, indicating the old format # Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '') auth_header = request.headers.get('Authorization', '')
if not auth_header: if not auth_header:
from controllers.files import bp as files_bp from controllers.files import bp as files_bp
from controllers.service_api import bp as service_api_bp from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp from controllers.web import bp as web_bp
from controllers.inner_api import bp as inner_api_bp


CORS(service_api_bp, CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
) )
app.register_blueprint(files_bp) app.register_blueprint(files_bp)


app.register_blueprint(inner_api_bp)



# create app # create app
app = create_app() app = create_app()

+ 9
- 0
api/config.py 查看文件

'TOOL_ICON_CACHE_MAX_AGE': 3600, 'TOOL_ICON_CACHE_MAX_AGE': 3600,
'MILVUS_DATABASE': 'default', 'MILVUS_DATABASE': 'default',
'KEYWORD_DATA_SOURCE_TYPE': 'database', 'KEYWORD_DATA_SOURCE_TYPE': 'database',
'INNER_API': 'False',
'ENTERPRISE_ENABLED': 'False',
} }




# Alternatively you can set it with `SECRET_KEY` environment variable. # Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY') self.SECRET_KEY = get_env('SECRET_KEY')


# Enable or disable the inner API.
self.INNER_API = get_bool_env('INNER_API')
# The inner API key is used to authenticate the inner API.
self.INNER_API_KEY = get_env('INNER_API_KEY')

# cors settings # cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')


self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')



class CloudEditionConfig(Config): class CloudEditionConfig(Config):



+ 3
- 1
api/controllers/console/__init__.py 查看文件

from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app, from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app,
saved_message, workflow) saved_message, workflow)
# Import workspace controllers # Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace
from .workspace import account, members, model_providers, models, tool_providers, workspace
# Import enterprise controllers
from .enterprise import enterprise_sso

+ 6
- 3
api/controllers/console/auth/login.py 查看文件



try: try:
account = AccountService.authenticate(args['email'], args['password']) account = AccountService.authenticate(args['email'], args['password'])
except services.errors.account.AccountLoginError:
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
except services.errors.account.AccountLoginError as e:
return {'code': 'unauthorized', 'message': str(e)}, 401


TenantService.create_owner_tenant_if_not_exist(account)
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}


AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)



+ 0
- 0
api/controllers/console/enterprise/__init__.py 查看文件


+ 59
- 0
api/controllers/console/enterprise/enterprise_sso.py 查看文件

from flask import current_app, redirect
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.setup import setup_required
from services.enterprise.enterprise_sso_service import EnterpriseSSOService


class EnterpriseSSOSamlLogin(Resource):

@setup_required
def get(self):
return EnterpriseSSOService.get_sso_saml_login()


class EnterpriseSSOSamlAcs(Resource):

@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('SAMLResponse', type=str, required=True, location='form')
args = parser.parse_args()
saml_response = args['SAMLResponse']

try:
token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
except Exception as e:
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')


class EnterpriseSSOOidcLogin(Resource):

@setup_required
def get(self):
return EnterpriseSSOService.get_sso_oidc_login()


class EnterpriseSSOOidcCallback(Resource):

@setup_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('state', type=str, required=True, location='args')
parser.add_argument('code', type=str, required=True, location='args')
parser.add_argument('oidc-state', type=str, required=True, location='cookies')
args = parser.parse_args()

try:
token = EnterpriseSSOService.get_sso_oidc_callback(args)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
except Exception as e:
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')


api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')

+ 7
- 0
api/controllers/console/feature.py 查看文件

from flask_login import current_user from flask_login import current_user
from flask_restful import Resource from flask_restful import Resource


from services.enterprise.enterprise_feature_service import EnterpriseFeatureService
from services.feature_service import FeatureService from services.feature_service import FeatureService


from . import api from . import api
return FeatureService.get_features(current_user.current_tenant_id).dict() return FeatureService.get_features(current_user.current_tenant_id).dict()




class EnterpriseFeatureApi(Resource):
def get(self):
return EnterpriseFeatureService.get_enterprise_features().dict()


api.add_resource(FeatureApi, '/features') api.add_resource(FeatureApi, '/features')
api.add_resource(EnterpriseFeatureApi, '/enterprise-features')

+ 2
- 0
api/controllers/console/setup.py 查看文件

password=args['password'] password=args['password']
) )


TenantService.create_owner_tenant_if_not_exist(account)

setup() setup()
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)



+ 12
- 1
api/controllers/console/workspace/workspace.py 查看文件

from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized


import services import services
from controllers.console import api from controllers.console import api
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import TimestampField from libs.helper import TimestampField
from libs.login import login_required from libs.login import login_required
from models.account import Tenant
from models.account import Tenant, TenantStatus
from services.account_service import TenantService from services.account_service import TenantService
from services.file_service import FileService from services.file_service import FileService
from services.workspace_service import WorkspaceService from services.workspace_service import WorkspaceService


tenant = current_user.current_tenant tenant = current_user.current_tenant


if tenant.status == TenantStatus.ARCHIVE:
tenants = TenantService.get_join_tenants(current_user)
# if there is any tenant, switch to the first one
if len(tenants) > 0:
TenantService.switch_tenant(current_user, tenants[0].id)
tenant = tenants[0]
# else, raise Unauthorized
else:
raise Unauthorized('workspace is archived')

return WorkspaceService.get_tenant_info(tenant), 200 return WorkspaceService.get_tenant_info(tenant), 200





+ 8
- 0
api/controllers/inner_api/__init__.py 查看文件

from flask import Blueprint
from libs.external_api import ExternalApi

bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
api = ExternalApi(bp)

from .workspace import workspace


+ 0
- 0
api/controllers/inner_api/workspace/__init__.py 查看文件


+ 37
- 0
api/controllers/inner_api/workspace/workspace.py 查看文件

from flask_restful import Resource, reqparse

from controllers.console.setup import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only
from events.tenant_event import tenant_was_created
from models.account import Account
from services.account_service import TenantService


class EnterpriseWorkspace(Resource):

@setup_required
@inner_api_only
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('owner_email', type=str, required=True, location='json')
args = parser.parse_args()

account = Account.query.filter_by(email=args['owner_email']).first()
if account is None:
return {
'message': 'owner account not found.'
}, 404

tenant = TenantService.create_tenant(args['name'])
TenantService.create_tenant_member(tenant, account, role='owner')

tenant_was_created.send(tenant)

return {
'message': 'enterprise workspace created.'
}


api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')

+ 61
- 0
api/controllers/inner_api/wraps.py 查看文件

from base64 import b64encode
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new

from flask import abort, current_app, request

from extensions.ext_database import db
from models.model import EndUser


def inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
if not current_app.config['INNER_API']:
abort(404)

# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get('X-Inner-Api-Key')
if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
abort(404)

return view(*args, **kwargs)

return decorated


def inner_api_user_auth(view):
@wraps(view)
def decorated(*args, **kwargs):
if not current_app.config['INNER_API']:
return view(*args, **kwargs)

# get header 'X-Inner-Api-Key'
authorization = request.headers.get('Authorization')
if not authorization:
return view(*args, **kwargs)

parts = authorization.split(':')
if len(parts) != 2:
return view(*args, **kwargs)

user_id, token = parts
if ' ' in user_id:
user_id = user_id.split(' ')[1]

inner_api_key = request.headers.get('X-Inner-Api-Key')

data_to_sign = f'DIFY {user_id}'

signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
signature = b64encode(signature.digest()).decode('utf-8')

if signature != token:
return view(*args, **kwargs)

kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()

return view(*args, **kwargs)

return decorated

+ 6
- 1
api/controllers/service_api/wraps.py 查看文件



from extensions.ext_database import db from extensions.ext_database import db
from libs.login import _get_user from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService from services.feature_service import FeatureService


if not app_model.enable_api: if not app_model.enable_api:
raise NotFound() raise NotFound()


tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
if tenant.status == TenantStatus.ARCHIVE:
raise NotFound()

kwargs['app_model'] = app_model kwargs['app_model'] = app_model


if fetch_user_arg: if fetch_user_arg:
.filter(Tenant.id == api_token.tenant_id) \ .filter(Tenant.id == api_token.tenant_id) \
.filter(TenantAccountJoin.tenant_id == Tenant.id) \ .filter(TenantAccountJoin.tenant_id == Tenant.id) \
.filter(TenantAccountJoin.role.in_(['owner'])) \ .filter(TenantAccountJoin.role.in_(['owner'])) \
.filter(Tenant.status == TenantStatus.NORMAL) \
.one_or_none() # TODO: only owner information is required, so only one is returned. .one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
tenant, ta = tenant_account_join tenant, ta = tenant_account_join

+ 4
- 0
api/controllers/web/site.py 查看文件

from controllers.web import api from controllers.web import api
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from extensions.ext_database import db from extensions.ext_database import db
from models.account import TenantStatus
from models.model import Site from models.model import Site
from services.feature_service import FeatureService from services.feature_service import FeatureService


if not site: if not site:
raise Forbidden() raise Forbidden()


if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()

can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo


return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)

+ 6
- 0
api/models/account.py 查看文件

def is_admin_or_owner(self): def is_admin_or_owner(self):
return self._current_tenant.current_role in ['admin', 'owner'] return self._current_tenant.current_role in ['admin', 'owner']



class TenantStatus(str, enum.Enum):
NORMAL = 'normal'
ARCHIVE = 'archive'


class Tenant(db.Model): class Tenant(db.Model):
__tablename__ = 'tenants' __tablename__ = 'tenants'
__table_args__ = ( __table_args__ = (

+ 9
- 4
api/services/account_service.py 查看文件



from flask import current_app from flask import current_app
from sqlalchemy import func from sqlalchemy import func
from werkzeug.exceptions import Forbidden
from werkzeug.exceptions import Unauthorized


from constants.languages import language_timezone_mapping, languages from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
return None return None


if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
raise Forbidden('Account is banned or closed.')
raise Unauthorized("Account is banned or closed.")


current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant: if current_tenant:
"""Get account join tenants""" """Get account join tenants"""
return db.session.query(Tenant).join( return db.session.query(Tenant).join(
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
).filter(TenantAccountJoin.account_id == account.id).all()
).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()


@staticmethod @staticmethod
def get_current_tenant_by_account(account: Account): def get_current_tenant_by_account(account: Account):
if tenant_id is None: if tenant_id is None:
raise ValueError("Tenant ID must be provided.") raise ValueError("Tenant ID must be provided.")


tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
).first()

if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else: else:

+ 0
- 0
api/services/enterprise/__init__.py 查看文件


+ 20
- 0
api/services/enterprise/base.py 查看文件

import os

import requests


class EnterpriseRequest:
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')

@classmethod
def send_request(cls, method, endpoint, json=None, params=None):
headers = {
"Content-Type": "application/json",
"Enterprise-Api-Secret-Key": cls.secret_key
}

url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)

return response.json()

+ 28
- 0
api/services/enterprise/enterprise_feature_service.py 查看文件

from flask import current_app
from pydantic import BaseModel

from services.enterprise.enterprise_service import EnterpriseService


class EnterpriseFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ''


class EnterpriseFeatureService:

@classmethod
def get_enterprise_features(cls) -> EnterpriseFeatureModel:
features = EnterpriseFeatureModel()

if current_app.config['ENTERPRISE_ENABLED']:
cls._fulfill_params_from_enterprise(features)

return features

@classmethod
def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info()

features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']

+ 8
- 0
api/services/enterprise/enterprise_service.py 查看文件

from services.enterprise.base import EnterpriseRequest


class EnterpriseService:

@classmethod
def get_info(cls):
return EnterpriseRequest.send_request('GET', '/info')

+ 60
- 0
api/services/enterprise/enterprise_sso_service.py 查看文件

import logging

from models.account import Account, AccountStatus
from services.account_service import AccountService, TenantService
from services.enterprise.base import EnterpriseRequest

logger = logging.getLogger(__name__)


class EnterpriseSSOService:

@classmethod
def get_sso_saml_login(cls) -> str:
return EnterpriseRequest.send_request('GET', '/sso/saml/login')

@classmethod
def post_sso_saml_acs(cls, saml_response: str) -> str:
response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
if 'email' not in response or response['email'] is None:
logger.exception(response)
raise Exception('Saml response is invalid')

return cls.login_with_email(response.get('email'))

@classmethod
def get_sso_oidc_login(cls):
return EnterpriseRequest.send_request('GET', '/sso/oidc/login')

@classmethod
def get_sso_oidc_callback(cls, args: dict):
state_from_query = args['state']
code_from_query = args['code']
state_from_cookies = args['oidc-state']

if state_from_cookies != state_from_query:
raise Exception('invalid state or code')

response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
if 'email' not in response or response['email'] is None:
logger.exception(response)
raise Exception('OIDC response is invalid')

return cls.login_with_email(response.get('email'))

@classmethod
def login_with_email(cls, email: str) -> str:
account = Account.query.filter_by(email=email).first()
if account is None:
raise Exception('account not found, please contact system admin to invite you to join in a workspace')

if account.status == AccountStatus.BANNED:
raise Exception('account is banned, please contact system admin')

tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")

token = AccountService.get_account_jwt_token(account)

return token

+ 4
- 0
web/app/components/header/account-dropdown/index.tsx 查看文件

url: '/logout', url: '/logout',
params: {}, params: {},
}) })

if (localStorage?.getItem('console_token'))
localStorage.removeItem('console_token')

router.push('/signin') router.push('/signin')
} }



+ 0
- 3
web/app/signin/_header.tsx 查看文件

const Header = () => { const Header = () => {
const { locale, setLocaleOnClient } = useContext(I18n) const { locale, setLocaleOnClient } = useContext(I18n)


if (localStorage?.getItem('console_token'))
localStorage.removeItem('console_token')

return <div className='flex items-center justify-between p-6 w-full'> return <div className='flex items-center justify-between p-6 w-full'>
<LogoSite /> <LogoSite />
<Select <Select

+ 87
- 0
web/app/signin/enterpriseSSOForm.tsx 查看文件

'use client'
import cn from 'classnames'
import { useRouter, useSearchParams } from 'next/navigation'
import type { FC } from 'react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise'
import Button from '@/app/components/base/button'

type EnterpriseSSOFormProps = {
protocol: string
}

const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({
protocol,
}) => {
const searchParams = useSearchParams()
const consoleToken = searchParams.get('console_token')
const message = searchParams.get('message')

const router = useRouter()
const { t } = useTranslation()

const [isLoading, setIsLoading] = useState(false)

useEffect(() => {
if (consoleToken) {
localStorage.setItem('console_token', consoleToken)
router.replace('/apps')
}

if (message) {
Toast.notify({
type: 'error',
message,
})
}
}, [])

const handleSSOLogin = () => {
setIsLoading(true)
if (protocol === 'saml') {
getSAMLSSOUrl().then((res) => {
router.push(res.url)
}).finally(() => {
setIsLoading(false)
})
}
else {
getOIDCSSOUrl().then((res) => {
document.cookie = `oidc-state=${res.state}`
router.push(res.url)
}).finally(() => {
setIsLoading(false)
})
}
}

return (
<div className={
cn(
'flex flex-col items-center w-full grow items-center justify-center',
'px-6',
'md:px-[108px]',
)
}>
<div className='flex flex-col md:w-[400px]'>
<div className="w-full mx-auto">
<h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2>
</div>
<div className="w-full mx-auto mt-10">
<Button
tabIndex={0}
type='primary'
onClick={() => { handleSSOLogin() }}
disabled={isLoading}
className="w-full !fone-medium !text-sm"
>{t('login.sso')}
</Button>
</div>
</div>
</div>
)
}

export default EnterpriseSSOForm

+ 11
- 2
web/app/signin/normalForm.tsx 查看文件

remember_me: true, remember_me: true,
}, },
}) })
localStorage.setItem('console_token', res.data)
router.replace('/apps')

if (res.result === 'success') {
localStorage.setItem('console_token', res.data)
router.replace('/apps')
}
else {
Toast.notify({
type: 'error',
message: res.data,
})
}
} }
finally { finally {
setIsLoading(false) setIsLoading(false)

+ 43
- 5
web/app/signin/page.tsx 查看文件

import React from 'react'
'use client'
import React, { useEffect, useState } from 'react'
import cn from 'classnames' import cn from 'classnames'
import Script from 'next/script' import Script from 'next/script'
import Loading from '../components/base/loading'
import Forms from './forms' import Forms from './forms'
import Header from './_header' import Header from './_header'
import style from './page.module.css' import style from './page.module.css'
import EnterpriseSSOForm from './enterpriseSSOForm'
import { IS_CE_EDITION } from '@/config' import { IS_CE_EDITION } from '@/config'
import { getEnterpriseFeatures } from '@/service/enterprise'
import type { EnterpriseFeatures } from '@/types/enterprise'
import { defaultEnterpriseFeatures } from '@/types/enterprise'


const SignIn = () => { const SignIn = () => {
const [loading, setLoading] = useState<boolean>(true)
const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures)

useEffect(() => {
getEnterpriseFeatures().then((res) => {
setEnterpriseFeatures(res)
}).finally(() => {
setLoading(false)
})
}, [])

return ( return (
<> <>
{!IS_CE_EDITION && ( {!IS_CE_EDITION && (
) )
}> }>
<Header /> <Header />
<Forms />
<div className='px-8 py-6 text-sm font-normal text-gray-500'>
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
</div>

{loading && (
<div className={
cn(
'flex flex-col items-center w-full grow items-center justify-center',
'px-6',
'md:px-[108px]',
)
}>
<Loading type='area' />
</div>
)}

{!loading && !enterpriseFeatures.sso_enforced_for_signin && (
<>
<Forms />
<div className='px-8 py-6 text-sm font-normal text-gray-500'>
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
</div>
</>
)}

{!loading && enterpriseFeatures.sso_enforced_for_signin && (
<EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} />
)}
</div> </div>


</div> </div>

+ 1
- 0
web/i18n/en-US/login.ts 查看文件

namePlaceholder: 'Your username', namePlaceholder: 'Your username',
forget: 'Forgot your password?', forget: 'Forgot your password?',
signBtn: 'Sign in', signBtn: 'Sign in',
sso: 'Continue with SSO',
installBtn: 'Set up', installBtn: 'Set up',
setAdminAccount: 'Setting up an admin account', setAdminAccount: 'Setting up an admin account',
setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.', setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.',

+ 14
- 0
web/service/enterprise.ts 查看文件

import { get } from './base'
import type { EnterpriseFeatures } from '@/types/enterprise'

export const getEnterpriseFeatures = () => {
return get<EnterpriseFeatures>('/enterprise-features')
}

export const getSAMLSSOUrl = () => {
return get<{ url: string }>('/enterprise/sso/saml/login')
}

export const getOIDCSSOUrl = () => {
return get<{ url: string; state: string }>('/enterprise/sso/oidc/login')
}

+ 9
- 0
web/types/enterprise.ts 查看文件

export type EnterpriseFeatures = {
sso_enforced_for_signin: boolean
sso_enforced_for_signin_protocol: string
}

export const defaultEnterpriseFeatures: EnterpriseFeatures = {
sso_enforced_for_signin: false,
sso_enforced_for_signin_protocol: '',
}

Loading…
取消
儲存