| @@ -115,7 +115,7 @@ def initialize_extensions(app): | |||
| @login_manager.request_loader | |||
| def load_user_from_request(request_from_flask_login): | |||
| """Load user based on the request.""" | |||
| if request.blueprint == 'console': | |||
| if request.blueprint in ['console', 'inner_api']: | |||
| # Check if the user_id contains a dot, indicating the old format | |||
| auth_header = request.headers.get('Authorization', '') | |||
| if not auth_header: | |||
| @@ -153,6 +153,7 @@ def register_blueprints(app): | |||
| from controllers.files import bp as files_bp | |||
| from controllers.service_api import bp as service_api_bp | |||
| from controllers.web import bp as web_bp | |||
| from controllers.inner_api import bp as inner_api_bp | |||
| CORS(service_api_bp, | |||
| allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], | |||
| @@ -188,6 +189,8 @@ def register_blueprints(app): | |||
| ) | |||
| app.register_blueprint(files_bp) | |||
| app.register_blueprint(inner_api_bp) | |||
| # create app | |||
| app = create_app() | |||
| @@ -69,6 +69,8 @@ DEFAULTS = { | |||
| 'TOOL_ICON_CACHE_MAX_AGE': 3600, | |||
| 'MILVUS_DATABASE': 'default', | |||
| 'KEYWORD_DATA_SOURCE_TYPE': 'database', | |||
| 'INNER_API': 'False', | |||
| 'ENTERPRISE_ENABLED': 'False', | |||
| } | |||
| @@ -133,6 +135,11 @@ class Config: | |||
| # Alternatively you can set it with `SECRET_KEY` environment variable. | |||
| self.SECRET_KEY = get_env('SECRET_KEY') | |||
| # Enable or disable the inner API. | |||
| self.INNER_API = get_bool_env('INNER_API') | |||
| # The inner API key is used to authenticate the inner API. | |||
| self.INNER_API_KEY = get_env('INNER_API_KEY') | |||
| # cors settings | |||
| self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( | |||
| 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) | |||
| @@ -327,6 +334,8 @@ class Config: | |||
| self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') | |||
| self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') | |||
| self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') | |||
| class CloudEditionConfig(Config): | |||
| @@ -19,4 +19,6 @@ from .datasets import data_source, datasets, datasets_document, datasets_segment | |||
| from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app, | |||
| saved_message, workflow) | |||
| # Import workspace controllers | |||
| from .workspace import account, members, model_providers, models, tool_providers, workspace | |||
| from .workspace import account, members, model_providers, models, tool_providers, workspace | |||
| # Import enterprise controllers | |||
| from .enterprise import enterprise_sso | |||
| @@ -26,10 +26,13 @@ class LoginApi(Resource): | |||
| try: | |||
| account = AccountService.authenticate(args['email'], args['password']) | |||
| except services.errors.account.AccountLoginError: | |||
| return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 | |||
| except services.errors.account.AccountLoginError as e: | |||
| return {'code': 'unauthorized', 'message': str(e)}, 401 | |||
| TenantService.create_owner_tenant_if_not_exist(account) | |||
| # SELF_HOSTED only have one workspace | |||
| tenants = TenantService.get_join_tenants(account) | |||
| if len(tenants) == 0: | |||
| return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} | |||
| AccountService.update_last_login(account, request) | |||
| @@ -0,0 +1,59 @@ | |||
| from flask import current_app, redirect | |||
| from flask_restful import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from services.enterprise.enterprise_sso_service import EnterpriseSSOService | |||
| class EnterpriseSSOSamlLogin(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| return EnterpriseSSOService.get_sso_saml_login() | |||
| class EnterpriseSSOSamlAcs(Resource): | |||
| @setup_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('SAMLResponse', type=str, required=True, location='form') | |||
| args = parser.parse_args() | |||
| saml_response = args['SAMLResponse'] | |||
| try: | |||
| token = EnterpriseSSOService.post_sso_saml_acs(saml_response) | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||
| except Exception as e: | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||
| class EnterpriseSSOOidcLogin(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| return EnterpriseSSOService.get_sso_oidc_login() | |||
| class EnterpriseSSOOidcCallback(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('state', type=str, required=True, location='args') | |||
| parser.add_argument('code', type=str, required=True, location='args') | |||
| parser.add_argument('oidc-state', type=str, required=True, location='cookies') | |||
| args = parser.parse_args() | |||
| try: | |||
| token = EnterpriseSSOService.get_sso_oidc_callback(args) | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||
| except Exception as e: | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||
| api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') | |||
| api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') | |||
| api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') | |||
| api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') | |||
| @@ -1,6 +1,7 @@ | |||
| from flask_login import current_user | |||
| from flask_restful import Resource | |||
| from services.enterprise.enterprise_feature_service import EnterpriseFeatureService | |||
| from services.feature_service import FeatureService | |||
| from . import api | |||
| @@ -14,4 +15,10 @@ class FeatureApi(Resource): | |||
| return FeatureService.get_features(current_user.current_tenant_id).dict() | |||
| class EnterpriseFeatureApi(Resource): | |||
| def get(self): | |||
| return EnterpriseFeatureService.get_enterprise_features().dict() | |||
| api.add_resource(FeatureApi, '/features') | |||
| api.add_resource(EnterpriseFeatureApi, '/enterprise-features') | |||
| @@ -58,6 +58,8 @@ class SetupApi(Resource): | |||
| password=args['password'] | |||
| ) | |||
| TenantService.create_owner_tenant_if_not_exist(account) | |||
| setup() | |||
| AccountService.update_last_login(account, request) | |||
| @@ -3,6 +3,7 @@ import logging | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse | |||
| from werkzeug.exceptions import Unauthorized | |||
| import services | |||
| from controllers.console import api | |||
| @@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi | |||
| from extensions.ext_database import db | |||
| from libs.helper import TimestampField | |||
| from libs.login import login_required | |||
| from models.account import Tenant | |||
| from models.account import Tenant, TenantStatus | |||
| from services.account_service import TenantService | |||
| from services.file_service import FileService | |||
| from services.workspace_service import WorkspaceService | |||
| @@ -116,6 +117,16 @@ class TenantApi(Resource): | |||
| tenant = current_user.current_tenant | |||
| if tenant.status == TenantStatus.ARCHIVE: | |||
| tenants = TenantService.get_join_tenants(current_user) | |||
| # if there is any tenant, switch to the first one | |||
| if len(tenants) > 0: | |||
| TenantService.switch_tenant(current_user, tenants[0].id) | |||
| tenant = tenants[0] | |||
| # else, raise Unauthorized | |||
| else: | |||
| raise Unauthorized('workspace is archived') | |||
| return WorkspaceService.get_tenant_info(tenant), 200 | |||
| @@ -0,0 +1,8 @@ | |||
| from flask import Blueprint | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') | |||
| api = ExternalApi(bp) | |||
| from .workspace import workspace | |||
| @@ -0,0 +1,37 @@ | |||
| from flask_restful import Resource, reqparse | |||
| from controllers.console.setup import setup_required | |||
| from controllers.inner_api import api | |||
| from controllers.inner_api.wraps import inner_api_only | |||
| from events.tenant_event import tenant_was_created | |||
| from models.account import Account | |||
| from services.account_service import TenantService | |||
| class EnterpriseWorkspace(Resource): | |||
| @setup_required | |||
| @inner_api_only | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, location='json') | |||
| parser.add_argument('owner_email', type=str, required=True, location='json') | |||
| args = parser.parse_args() | |||
| account = Account.query.filter_by(email=args['owner_email']).first() | |||
| if account is None: | |||
| return { | |||
| 'message': 'owner account not found.' | |||
| }, 404 | |||
| tenant = TenantService.create_tenant(args['name']) | |||
| TenantService.create_tenant_member(tenant, account, role='owner') | |||
| tenant_was_created.send(tenant) | |||
| return { | |||
| 'message': 'enterprise workspace created.' | |||
| } | |||
| api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') | |||
| @@ -0,0 +1,61 @@ | |||
| from base64 import b64encode | |||
| from functools import wraps | |||
| from hashlib import sha1 | |||
| from hmac import new as hmac_new | |||
| from flask import abort, current_app, request | |||
| from extensions.ext_database import db | |||
| from models.model import EndUser | |||
| def inner_api_only(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not current_app.config['INNER_API']: | |||
| abort(404) | |||
| # get header 'X-Inner-Api-Key' | |||
| inner_api_key = request.headers.get('X-Inner-Api-Key') | |||
| if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']: | |||
| abort(404) | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| def inner_api_user_auth(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not current_app.config['INNER_API']: | |||
| return view(*args, **kwargs) | |||
| # get header 'X-Inner-Api-Key' | |||
| authorization = request.headers.get('Authorization') | |||
| if not authorization: | |||
| return view(*args, **kwargs) | |||
| parts = authorization.split(':') | |||
| if len(parts) != 2: | |||
| return view(*args, **kwargs) | |||
| user_id, token = parts | |||
| if ' ' in user_id: | |||
| user_id = user_id.split(' ')[1] | |||
| inner_api_key = request.headers.get('X-Inner-Api-Key') | |||
| data_to_sign = f'DIFY {user_id}' | |||
| signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) | |||
| signature = b64encode(signature.digest()).decode('utf-8') | |||
| if signature != token: | |||
| return view(*args, **kwargs) | |||
| kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| @@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized | |||
| from extensions.ext_database import db | |||
| from libs.login import _get_user | |||
| from models.account import Account, Tenant, TenantAccountJoin | |||
| from models.account import Account, Tenant, TenantAccountJoin, TenantStatus | |||
| from models.model import ApiToken, App, EndUser | |||
| from services.feature_service import FeatureService | |||
| @@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| if not app_model.enable_api: | |||
| raise NotFound() | |||
| tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() | |||
| if tenant.status == TenantStatus.ARCHIVE: | |||
| raise NotFound() | |||
| kwargs['app_model'] = app_model | |||
| if fetch_user_arg: | |||
| @@ -137,6 +141,7 @@ def validate_dataset_token(view=None): | |||
| .filter(Tenant.id == api_token.tenant_id) \ | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | |||
| .filter(TenantAccountJoin.role.in_(['owner'])) \ | |||
| .filter(Tenant.status == TenantStatus.NORMAL) \ | |||
| .one_or_none() # TODO: only owner information is required, so only one is returned. | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| @@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden | |||
| from controllers.web import api | |||
| from controllers.web.wraps import WebApiResource | |||
| from extensions.ext_database import db | |||
| from models.account import TenantStatus | |||
| from models.model import Site | |||
| from services.feature_service import FeatureService | |||
| @@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource): | |||
| if not site: | |||
| raise Forbidden() | |||
| if app_model.tenant.status == TenantStatus.ARCHIVE: | |||
| raise Forbidden() | |||
| can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo | |||
| return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) | |||
| @@ -105,6 +105,12 @@ class Account(UserMixin, db.Model): | |||
| def is_admin_or_owner(self): | |||
| return self._current_tenant.current_role in ['admin', 'owner'] | |||
| class TenantStatus(str, enum.Enum): | |||
| NORMAL = 'normal' | |||
| ARCHIVE = 'archive' | |||
| class Tenant(db.Model): | |||
| __tablename__ = 'tenants' | |||
| __table_args__ = ( | |||
| @@ -8,7 +8,7 @@ from typing import Any, Optional | |||
| from flask import current_app | |||
| from sqlalchemy import func | |||
| from werkzeug.exceptions import Forbidden | |||
| from werkzeug.exceptions import Unauthorized | |||
| from constants.languages import language_timezone_mapping, languages | |||
| from events.tenant_event import tenant_was_created | |||
| @@ -44,7 +44,7 @@ class AccountService: | |||
| return None | |||
| if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: | |||
| raise Forbidden('Account is banned or closed.') | |||
| raise Unauthorized("Account is banned or closed.") | |||
| current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() | |||
| if current_tenant: | |||
| @@ -255,7 +255,7 @@ class TenantService: | |||
| """Get account join tenants""" | |||
| return db.session.query(Tenant).join( | |||
| TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id | |||
| ).filter(TenantAccountJoin.account_id == account.id).all() | |||
| ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() | |||
| @staticmethod | |||
| def get_current_tenant_by_account(account: Account): | |||
| @@ -279,7 +279,12 @@ class TenantService: | |||
| if tenant_id is None: | |||
| raise ValueError("Tenant ID must be provided.") | |||
| tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() | |||
| tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( | |||
| TenantAccountJoin.account_id == account.id, | |||
| TenantAccountJoin.tenant_id == tenant_id, | |||
| Tenant.status == TenantStatus.NORMAL, | |||
| ).first() | |||
| if not tenant_account_join: | |||
| raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") | |||
| else: | |||
| @@ -0,0 +1,20 @@ | |||
| import os | |||
| import requests | |||
| class EnterpriseRequest: | |||
| base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') | |||
| secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') | |||
| @classmethod | |||
| def send_request(cls, method, endpoint, json=None, params=None): | |||
| headers = { | |||
| "Content-Type": "application/json", | |||
| "Enterprise-Api-Secret-Key": cls.secret_key | |||
| } | |||
| url = f"{cls.base_url}{endpoint}" | |||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||
| return response.json() | |||
| @@ -0,0 +1,28 @@ | |||
| from flask import current_app | |||
| from pydantic import BaseModel | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| class EnterpriseFeatureModel(BaseModel): | |||
| sso_enforced_for_signin: bool = False | |||
| sso_enforced_for_signin_protocol: str = '' | |||
| class EnterpriseFeatureService: | |||
| @classmethod | |||
| def get_enterprise_features(cls) -> EnterpriseFeatureModel: | |||
| features = EnterpriseFeatureModel() | |||
| if current_app.config['ENTERPRISE_ENABLED']: | |||
| cls._fulfill_params_from_enterprise(features) | |||
| return features | |||
| @classmethod | |||
| def _fulfill_params_from_enterprise(cls, features): | |||
| enterprise_info = EnterpriseService.get_info() | |||
| features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] | |||
| features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] | |||
| @@ -0,0 +1,8 @@ | |||
| from services.enterprise.base import EnterpriseRequest | |||
| class EnterpriseService: | |||
| @classmethod | |||
| def get_info(cls): | |||
| return EnterpriseRequest.send_request('GET', '/info') | |||
| @@ -0,0 +1,60 @@ | |||
| import logging | |||
| from models.account import Account, AccountStatus | |||
| from services.account_service import AccountService, TenantService | |||
| from services.enterprise.base import EnterpriseRequest | |||
| logger = logging.getLogger(__name__) | |||
| class EnterpriseSSOService: | |||
| @classmethod | |||
| def get_sso_saml_login(cls) -> str: | |||
| return EnterpriseRequest.send_request('GET', '/sso/saml/login') | |||
| @classmethod | |||
| def post_sso_saml_acs(cls, saml_response: str) -> str: | |||
| response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) | |||
| if 'email' not in response or response['email'] is None: | |||
| logger.exception(response) | |||
| raise Exception('Saml response is invalid') | |||
| return cls.login_with_email(response.get('email')) | |||
| @classmethod | |||
| def get_sso_oidc_login(cls): | |||
| return EnterpriseRequest.send_request('GET', '/sso/oidc/login') | |||
| @classmethod | |||
| def get_sso_oidc_callback(cls, args: dict): | |||
| state_from_query = args['state'] | |||
| code_from_query = args['code'] | |||
| state_from_cookies = args['oidc-state'] | |||
| if state_from_cookies != state_from_query: | |||
| raise Exception('invalid state or code') | |||
| response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) | |||
| if 'email' not in response or response['email'] is None: | |||
| logger.exception(response) | |||
| raise Exception('OIDC response is invalid') | |||
| return cls.login_with_email(response.get('email')) | |||
| @classmethod | |||
| def login_with_email(cls, email: str) -> str: | |||
| account = Account.query.filter_by(email=email).first() | |||
| if account is None: | |||
| raise Exception('account not found, please contact system admin to invite you to join in a workspace') | |||
| if account.status == AccountStatus.BANNED: | |||
| raise Exception('account is banned, please contact system admin') | |||
| tenants = TenantService.get_join_tenants(account) | |||
| if len(tenants) == 0: | |||
| raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") | |||
| token = AccountService.get_account_jwt_token(account) | |||
| return token | |||
| @@ -39,6 +39,10 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { | |||
| url: '/logout', | |||
| params: {}, | |||
| }) | |||
| if (localStorage?.getItem('console_token')) | |||
| localStorage.removeItem('console_token') | |||
| router.push('/signin') | |||
| } | |||
| @@ -10,9 +10,6 @@ import LogoSite from '@/app/components/base/logo/logo-site' | |||
| const Header = () => { | |||
| const { locale, setLocaleOnClient } = useContext(I18n) | |||
| if (localStorage?.getItem('console_token')) | |||
| localStorage.removeItem('console_token') | |||
| return <div className='flex items-center justify-between p-6 w-full'> | |||
| <LogoSite /> | |||
| <Select | |||
| @@ -0,0 +1,87 @@ | |||
| 'use client' | |||
| import cn from 'classnames' | |||
| import { useRouter, useSearchParams } from 'next/navigation' | |||
| import type { FC } from 'react' | |||
| import { useEffect, useState } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import Toast from '@/app/components/base/toast' | |||
| import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise' | |||
| import Button from '@/app/components/base/button' | |||
| type EnterpriseSSOFormProps = { | |||
| protocol: string | |||
| } | |||
| const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({ | |||
| protocol, | |||
| }) => { | |||
| const searchParams = useSearchParams() | |||
| const consoleToken = searchParams.get('console_token') | |||
| const message = searchParams.get('message') | |||
| const router = useRouter() | |||
| const { t } = useTranslation() | |||
| const [isLoading, setIsLoading] = useState(false) | |||
| useEffect(() => { | |||
| if (consoleToken) { | |||
| localStorage.setItem('console_token', consoleToken) | |||
| router.replace('/apps') | |||
| } | |||
| if (message) { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message, | |||
| }) | |||
| } | |||
| }, []) | |||
| const handleSSOLogin = () => { | |||
| setIsLoading(true) | |||
| if (protocol === 'saml') { | |||
| getSAMLSSOUrl().then((res) => { | |||
| router.push(res.url) | |||
| }).finally(() => { | |||
| setIsLoading(false) | |||
| }) | |||
| } | |||
| else { | |||
| getOIDCSSOUrl().then((res) => { | |||
| document.cookie = `oidc-state=${res.state}` | |||
| router.push(res.url) | |||
| }).finally(() => { | |||
| setIsLoading(false) | |||
| }) | |||
| } | |||
| } | |||
| return ( | |||
| <div className={ | |||
| cn( | |||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||
| 'px-6', | |||
| 'md:px-[108px]', | |||
| ) | |||
| }> | |||
| <div className='flex flex-col md:w-[400px]'> | |||
| <div className="w-full mx-auto"> | |||
| <h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2> | |||
| </div> | |||
| <div className="w-full mx-auto mt-10"> | |||
| <Button | |||
| tabIndex={0} | |||
| type='primary' | |||
| onClick={() => { handleSSOLogin() }} | |||
| disabled={isLoading} | |||
| className="w-full !fone-medium !text-sm" | |||
| >{t('login.sso')} | |||
| </Button> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| ) | |||
| } | |||
| export default EnterpriseSSOForm | |||
| @@ -96,8 +96,17 @@ const NormalForm = () => { | |||
| remember_me: true, | |||
| }, | |||
| }) | |||
| localStorage.setItem('console_token', res.data) | |||
| router.replace('/apps') | |||
| if (res.result === 'success') { | |||
| localStorage.setItem('console_token', res.data) | |||
| router.replace('/apps') | |||
| } | |||
| else { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: res.data, | |||
| }) | |||
| } | |||
| } | |||
| finally { | |||
| setIsLoading(false) | |||
| @@ -1,12 +1,29 @@ | |||
| import React from 'react' | |||
| 'use client' | |||
| import React, { useEffect, useState } from 'react' | |||
| import cn from 'classnames' | |||
| import Script from 'next/script' | |||
| import Loading from '../components/base/loading' | |||
| import Forms from './forms' | |||
| import Header from './_header' | |||
| import style from './page.module.css' | |||
| import EnterpriseSSOForm from './enterpriseSSOForm' | |||
| import { IS_CE_EDITION } from '@/config' | |||
| import { getEnterpriseFeatures } from '@/service/enterprise' | |||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||
| import { defaultEnterpriseFeatures } from '@/types/enterprise' | |||
| const SignIn = () => { | |||
| const [loading, setLoading] = useState<boolean>(true) | |||
| const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures) | |||
| useEffect(() => { | |||
| getEnterpriseFeatures().then((res) => { | |||
| setEnterpriseFeatures(res) | |||
| }).finally(() => { | |||
| setLoading(false) | |||
| }) | |||
| }, []) | |||
| return ( | |||
| <> | |||
| {!IS_CE_EDITION && ( | |||
| @@ -40,10 +57,31 @@ gtag('config', 'AW-11217955271"'); | |||
| ) | |||
| }> | |||
| <Header /> | |||
| <Forms /> | |||
| <div className='px-8 py-6 text-sm font-normal text-gray-500'> | |||
| © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. | |||
| </div> | |||
| {loading && ( | |||
| <div className={ | |||
| cn( | |||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||
| 'px-6', | |||
| 'md:px-[108px]', | |||
| ) | |||
| }> | |||
| <Loading type='area' /> | |||
| </div> | |||
| )} | |||
| {!loading && !enterpriseFeatures.sso_enforced_for_signin && ( | |||
| <> | |||
| <Forms /> | |||
| <div className='px-8 py-6 text-sm font-normal text-gray-500'> | |||
| © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. | |||
| </div> | |||
| </> | |||
| )} | |||
| {!loading && enterpriseFeatures.sso_enforced_for_signin && ( | |||
| <EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} /> | |||
| )} | |||
| </div> | |||
| </div> | |||
| @@ -9,6 +9,7 @@ const translation = { | |||
| namePlaceholder: 'Your username', | |||
| forget: 'Forgot your password?', | |||
| signBtn: 'Sign in', | |||
| sso: 'Continue with SSO', | |||
| installBtn: 'Set up', | |||
| setAdminAccount: 'Setting up an admin account', | |||
| setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.', | |||
| @@ -0,0 +1,14 @@ | |||
| import { get } from './base' | |||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||
| export const getEnterpriseFeatures = () => { | |||
| return get<EnterpriseFeatures>('/enterprise-features') | |||
| } | |||
| export const getSAMLSSOUrl = () => { | |||
| return get<{ url: string }>('/enterprise/sso/saml/login') | |||
| } | |||
| export const getOIDCSSOUrl = () => { | |||
| return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') | |||
| } | |||
| @@ -0,0 +1,9 @@ | |||
| export type EnterpriseFeatures = { | |||
| sso_enforced_for_signin: boolean | |||
| sso_enforced_for_signin_protocol: string | |||
| } | |||
| export const defaultEnterpriseFeatures: EnterpriseFeatures = { | |||
| sso_enforced_for_signin: false, | |||
| sso_enforced_for_signin_protocol: '', | |||
| } | |||