| @login_manager.request_loader | @login_manager.request_loader | ||||
| def load_user_from_request(request_from_flask_login): | def load_user_from_request(request_from_flask_login): | ||||
| """Load user based on the request.""" | """Load user based on the request.""" | ||||
| if request.blueprint == 'console': | |||||
| if request.blueprint in ['console', 'inner_api']: | |||||
| # Check if the user_id contains a dot, indicating the old format | # Check if the user_id contains a dot, indicating the old format | ||||
| auth_header = request.headers.get('Authorization', '') | auth_header = request.headers.get('Authorization', '') | ||||
| if not auth_header: | if not auth_header: | ||||
| from controllers.files import bp as files_bp | from controllers.files import bp as files_bp | ||||
| from controllers.service_api import bp as service_api_bp | from controllers.service_api import bp as service_api_bp | ||||
| from controllers.web import bp as web_bp | from controllers.web import bp as web_bp | ||||
| from controllers.inner_api import bp as inner_api_bp | |||||
| CORS(service_api_bp, | CORS(service_api_bp, | ||||
| allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], | allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], | ||||
| ) | ) | ||||
| app.register_blueprint(files_bp) | app.register_blueprint(files_bp) | ||||
| app.register_blueprint(inner_api_bp) | |||||
| # create app | # create app | ||||
| app = create_app() | app = create_app() |
| 'TOOL_ICON_CACHE_MAX_AGE': 3600, | 'TOOL_ICON_CACHE_MAX_AGE': 3600, | ||||
| 'MILVUS_DATABASE': 'default', | 'MILVUS_DATABASE': 'default', | ||||
| 'KEYWORD_DATA_SOURCE_TYPE': 'database', | 'KEYWORD_DATA_SOURCE_TYPE': 'database', | ||||
| 'INNER_API': 'False', | |||||
| 'ENTERPRISE_ENABLED': 'False', | |||||
| } | } | ||||
| # Alternatively you can set it with `SECRET_KEY` environment variable. | # Alternatively you can set it with `SECRET_KEY` environment variable. | ||||
| self.SECRET_KEY = get_env('SECRET_KEY') | self.SECRET_KEY = get_env('SECRET_KEY') | ||||
| # Enable or disable the inner API. | |||||
| self.INNER_API = get_bool_env('INNER_API') | |||||
| # The inner API key is used to authenticate the inner API. | |||||
| self.INNER_API_KEY = get_env('INNER_API_KEY') | |||||
| # cors settings | # cors settings | ||||
| self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( | self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( | ||||
| 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) | 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) | ||||
| self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') | self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') | ||||
| self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') | self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') | ||||
| self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') | |||||
| class CloudEditionConfig(Config): | class CloudEditionConfig(Config): | ||||
| from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app, | from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app, | ||||
| saved_message, workflow) | saved_message, workflow) | ||||
| # Import workspace controllers | # Import workspace controllers | ||||
| from .workspace import account, members, model_providers, models, tool_providers, workspace | |||||
| from .workspace import account, members, model_providers, models, tool_providers, workspace | |||||
| # Import enterprise controllers | |||||
| from .enterprise import enterprise_sso |
| try: | try: | ||||
| account = AccountService.authenticate(args['email'], args['password']) | account = AccountService.authenticate(args['email'], args['password']) | ||||
| except services.errors.account.AccountLoginError: | |||||
| return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 | |||||
| except services.errors.account.AccountLoginError as e: | |||||
| return {'code': 'unauthorized', 'message': str(e)}, 401 | |||||
| TenantService.create_owner_tenant_if_not_exist(account) | |||||
| # SELF_HOSTED only have one workspace | |||||
| tenants = TenantService.get_join_tenants(account) | |||||
| if len(tenants) == 0: | |||||
| return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} | |||||
| AccountService.update_last_login(account, request) | AccountService.update_last_login(account, request) | ||||
| from flask import current_app, redirect | |||||
| from flask_restful import Resource, reqparse | |||||
| from controllers.console import api | |||||
| from controllers.console.setup import setup_required | |||||
| from services.enterprise.enterprise_sso_service import EnterpriseSSOService | |||||
| class EnterpriseSSOSamlLogin(Resource): | |||||
| @setup_required | |||||
| def get(self): | |||||
| return EnterpriseSSOService.get_sso_saml_login() | |||||
| class EnterpriseSSOSamlAcs(Resource): | |||||
| @setup_required | |||||
| def post(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('SAMLResponse', type=str, required=True, location='form') | |||||
| args = parser.parse_args() | |||||
| saml_response = args['SAMLResponse'] | |||||
| try: | |||||
| token = EnterpriseSSOService.post_sso_saml_acs(saml_response) | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||||
| except Exception as e: | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||||
| class EnterpriseSSOOidcLogin(Resource): | |||||
| @setup_required | |||||
| def get(self): | |||||
| return EnterpriseSSOService.get_sso_oidc_login() | |||||
| class EnterpriseSSOOidcCallback(Resource): | |||||
| @setup_required | |||||
| def get(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('state', type=str, required=True, location='args') | |||||
| parser.add_argument('code', type=str, required=True, location='args') | |||||
| parser.add_argument('oidc-state', type=str, required=True, location='cookies') | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| token = EnterpriseSSOService.get_sso_oidc_callback(args) | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||||
| except Exception as e: | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||||
| api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') | |||||
| api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') | |||||
| api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') | |||||
| api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') |
| from flask_login import current_user | from flask_login import current_user | ||||
| from flask_restful import Resource | from flask_restful import Resource | ||||
| from services.enterprise.enterprise_feature_service import EnterpriseFeatureService | |||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| from . import api | from . import api | ||||
| return FeatureService.get_features(current_user.current_tenant_id).dict() | return FeatureService.get_features(current_user.current_tenant_id).dict() | ||||
| class EnterpriseFeatureApi(Resource): | |||||
| def get(self): | |||||
| return EnterpriseFeatureService.get_enterprise_features().dict() | |||||
| api.add_resource(FeatureApi, '/features') | api.add_resource(FeatureApi, '/features') | ||||
| api.add_resource(EnterpriseFeatureApi, '/enterprise-features') |
| password=args['password'] | password=args['password'] | ||||
| ) | ) | ||||
| TenantService.create_owner_tenant_if_not_exist(account) | |||||
| setup() | setup() | ||||
| AccountService.update_last_login(account, request) | AccountService.update_last_login(account, request) | ||||
| from flask import request | from flask import request | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse | from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse | ||||
| from werkzeug.exceptions import Unauthorized | |||||
| import services | import services | ||||
| from controllers.console import api | from controllers.console import api | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.helper import TimestampField | from libs.helper import TimestampField | ||||
| from libs.login import login_required | from libs.login import login_required | ||||
| from models.account import Tenant | |||||
| from models.account import Tenant, TenantStatus | |||||
| from services.account_service import TenantService | from services.account_service import TenantService | ||||
| from services.file_service import FileService | from services.file_service import FileService | ||||
| from services.workspace_service import WorkspaceService | from services.workspace_service import WorkspaceService | ||||
| tenant = current_user.current_tenant | tenant = current_user.current_tenant | ||||
| if tenant.status == TenantStatus.ARCHIVE: | |||||
| tenants = TenantService.get_join_tenants(current_user) | |||||
| # if there is any tenant, switch to the first one | |||||
| if len(tenants) > 0: | |||||
| TenantService.switch_tenant(current_user, tenants[0].id) | |||||
| tenant = tenants[0] | |||||
| # else, raise Unauthorized | |||||
| else: | |||||
| raise Unauthorized('workspace is archived') | |||||
| return WorkspaceService.get_tenant_info(tenant), 200 | return WorkspaceService.get_tenant_info(tenant), 200 | ||||
| from flask import Blueprint | |||||
| from libs.external_api import ExternalApi | |||||
| bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') | |||||
| api = ExternalApi(bp) | |||||
| from .workspace import workspace | |||||
| from flask_restful import Resource, reqparse | |||||
| from controllers.console.setup import setup_required | |||||
| from controllers.inner_api import api | |||||
| from controllers.inner_api.wraps import inner_api_only | |||||
| from events.tenant_event import tenant_was_created | |||||
| from models.account import Account | |||||
| from services.account_service import TenantService | |||||
| class EnterpriseWorkspace(Resource): | |||||
| @setup_required | |||||
| @inner_api_only | |||||
| def post(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('name', type=str, required=True, location='json') | |||||
| parser.add_argument('owner_email', type=str, required=True, location='json') | |||||
| args = parser.parse_args() | |||||
| account = Account.query.filter_by(email=args['owner_email']).first() | |||||
| if account is None: | |||||
| return { | |||||
| 'message': 'owner account not found.' | |||||
| }, 404 | |||||
| tenant = TenantService.create_tenant(args['name']) | |||||
| TenantService.create_tenant_member(tenant, account, role='owner') | |||||
| tenant_was_created.send(tenant) | |||||
| return { | |||||
| 'message': 'enterprise workspace created.' | |||||
| } | |||||
| api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') |
| from base64 import b64encode | |||||
| from functools import wraps | |||||
| from hashlib import sha1 | |||||
| from hmac import new as hmac_new | |||||
| from flask import abort, current_app, request | |||||
| from extensions.ext_database import db | |||||
| from models.model import EndUser | |||||
| def inner_api_only(view): | |||||
| @wraps(view) | |||||
| def decorated(*args, **kwargs): | |||||
| if not current_app.config['INNER_API']: | |||||
| abort(404) | |||||
| # get header 'X-Inner-Api-Key' | |||||
| inner_api_key = request.headers.get('X-Inner-Api-Key') | |||||
| if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']: | |||||
| abort(404) | |||||
| return view(*args, **kwargs) | |||||
| return decorated | |||||
| def inner_api_user_auth(view): | |||||
| @wraps(view) | |||||
| def decorated(*args, **kwargs): | |||||
| if not current_app.config['INNER_API']: | |||||
| return view(*args, **kwargs) | |||||
| # get header 'X-Inner-Api-Key' | |||||
| authorization = request.headers.get('Authorization') | |||||
| if not authorization: | |||||
| return view(*args, **kwargs) | |||||
| parts = authorization.split(':') | |||||
| if len(parts) != 2: | |||||
| return view(*args, **kwargs) | |||||
| user_id, token = parts | |||||
| if ' ' in user_id: | |||||
| user_id = user_id.split(' ')[1] | |||||
| inner_api_key = request.headers.get('X-Inner-Api-Key') | |||||
| data_to_sign = f'DIFY {user_id}' | |||||
| signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) | |||||
| signature = b64encode(signature.digest()).decode('utf-8') | |||||
| if signature != token: | |||||
| return view(*args, **kwargs) | |||||
| kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() | |||||
| return view(*args, **kwargs) | |||||
| return decorated |
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.login import _get_user | from libs.login import _get_user | ||||
| from models.account import Account, Tenant, TenantAccountJoin | |||||
| from models.account import Account, Tenant, TenantAccountJoin, TenantStatus | |||||
| from models.model import ApiToken, App, EndUser | from models.model import ApiToken, App, EndUser | ||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| if not app_model.enable_api: | if not app_model.enable_api: | ||||
| raise NotFound() | raise NotFound() | ||||
| tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() | |||||
| if tenant.status == TenantStatus.ARCHIVE: | |||||
| raise NotFound() | |||||
| kwargs['app_model'] = app_model | kwargs['app_model'] = app_model | ||||
| if fetch_user_arg: | if fetch_user_arg: | ||||
| .filter(Tenant.id == api_token.tenant_id) \ | .filter(Tenant.id == api_token.tenant_id) \ | ||||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | ||||
| .filter(TenantAccountJoin.role.in_(['owner'])) \ | .filter(TenantAccountJoin.role.in_(['owner'])) \ | ||||
| .filter(Tenant.status == TenantStatus.NORMAL) \ | |||||
| .one_or_none() # TODO: only owner information is required, so only one is returned. | .one_or_none() # TODO: only owner information is required, so only one is returned. | ||||
| if tenant_account_join: | if tenant_account_join: | ||||
| tenant, ta = tenant_account_join | tenant, ta = tenant_account_join |
| from controllers.web import api | from controllers.web import api | ||||
| from controllers.web.wraps import WebApiResource | from controllers.web.wraps import WebApiResource | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import TenantStatus | |||||
| from models.model import Site | from models.model import Site | ||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| if not site: | if not site: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| if app_model.tenant.status == TenantStatus.ARCHIVE: | |||||
| raise Forbidden() | |||||
| can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo | can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo | ||||
| return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) | return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) |
| def is_admin_or_owner(self): | def is_admin_or_owner(self): | ||||
| return self._current_tenant.current_role in ['admin', 'owner'] | return self._current_tenant.current_role in ['admin', 'owner'] | ||||
| class TenantStatus(str, enum.Enum): | |||||
| NORMAL = 'normal' | |||||
| ARCHIVE = 'archive' | |||||
| class Tenant(db.Model): | class Tenant(db.Model): | ||||
| __tablename__ = 'tenants' | __tablename__ = 'tenants' | ||||
| __table_args__ = ( | __table_args__ = ( |
| from flask import current_app | from flask import current_app | ||||
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from werkzeug.exceptions import Forbidden | |||||
| from werkzeug.exceptions import Unauthorized | |||||
| from constants.languages import language_timezone_mapping, languages | from constants.languages import language_timezone_mapping, languages | ||||
| from events.tenant_event import tenant_was_created | from events.tenant_event import tenant_was_created | ||||
| return None | return None | ||||
| if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: | if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: | ||||
| raise Forbidden('Account is banned or closed.') | |||||
| raise Unauthorized("Account is banned or closed.") | |||||
| current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() | current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() | ||||
| if current_tenant: | if current_tenant: | ||||
| """Get account join tenants""" | """Get account join tenants""" | ||||
| return db.session.query(Tenant).join( | return db.session.query(Tenant).join( | ||||
| TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id | TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id | ||||
| ).filter(TenantAccountJoin.account_id == account.id).all() | |||||
| ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() | |||||
| @staticmethod | @staticmethod | ||||
| def get_current_tenant_by_account(account: Account): | def get_current_tenant_by_account(account: Account): | ||||
| if tenant_id is None: | if tenant_id is None: | ||||
| raise ValueError("Tenant ID must be provided.") | raise ValueError("Tenant ID must be provided.") | ||||
| tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() | |||||
| tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( | |||||
| TenantAccountJoin.account_id == account.id, | |||||
| TenantAccountJoin.tenant_id == tenant_id, | |||||
| Tenant.status == TenantStatus.NORMAL, | |||||
| ).first() | |||||
| if not tenant_account_join: | if not tenant_account_join: | ||||
| raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") | raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") | ||||
| else: | else: |
| import os | |||||
| import requests | |||||
| class EnterpriseRequest: | |||||
| base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') | |||||
| secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') | |||||
| @classmethod | |||||
| def send_request(cls, method, endpoint, json=None, params=None): | |||||
| headers = { | |||||
| "Content-Type": "application/json", | |||||
| "Enterprise-Api-Secret-Key": cls.secret_key | |||||
| } | |||||
| url = f"{cls.base_url}{endpoint}" | |||||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||||
| return response.json() |
| from flask import current_app | |||||
| from pydantic import BaseModel | |||||
| from services.enterprise.enterprise_service import EnterpriseService | |||||
| class EnterpriseFeatureModel(BaseModel): | |||||
| sso_enforced_for_signin: bool = False | |||||
| sso_enforced_for_signin_protocol: str = '' | |||||
| class EnterpriseFeatureService: | |||||
| @classmethod | |||||
| def get_enterprise_features(cls) -> EnterpriseFeatureModel: | |||||
| features = EnterpriseFeatureModel() | |||||
| if current_app.config['ENTERPRISE_ENABLED']: | |||||
| cls._fulfill_params_from_enterprise(features) | |||||
| return features | |||||
| @classmethod | |||||
| def _fulfill_params_from_enterprise(cls, features): | |||||
| enterprise_info = EnterpriseService.get_info() | |||||
| features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] | |||||
| features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] |
| from services.enterprise.base import EnterpriseRequest | |||||
| class EnterpriseService: | |||||
| @classmethod | |||||
| def get_info(cls): | |||||
| return EnterpriseRequest.send_request('GET', '/info') |
| import logging | |||||
| from models.account import Account, AccountStatus | |||||
| from services.account_service import AccountService, TenantService | |||||
| from services.enterprise.base import EnterpriseRequest | |||||
| logger = logging.getLogger(__name__) | |||||
| class EnterpriseSSOService: | |||||
| @classmethod | |||||
| def get_sso_saml_login(cls) -> str: | |||||
| return EnterpriseRequest.send_request('GET', '/sso/saml/login') | |||||
| @classmethod | |||||
| def post_sso_saml_acs(cls, saml_response: str) -> str: | |||||
| response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) | |||||
| if 'email' not in response or response['email'] is None: | |||||
| logger.exception(response) | |||||
| raise Exception('Saml response is invalid') | |||||
| return cls.login_with_email(response.get('email')) | |||||
| @classmethod | |||||
| def get_sso_oidc_login(cls): | |||||
| return EnterpriseRequest.send_request('GET', '/sso/oidc/login') | |||||
| @classmethod | |||||
| def get_sso_oidc_callback(cls, args: dict): | |||||
| state_from_query = args['state'] | |||||
| code_from_query = args['code'] | |||||
| state_from_cookies = args['oidc-state'] | |||||
| if state_from_cookies != state_from_query: | |||||
| raise Exception('invalid state or code') | |||||
| response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) | |||||
| if 'email' not in response or response['email'] is None: | |||||
| logger.exception(response) | |||||
| raise Exception('OIDC response is invalid') | |||||
| return cls.login_with_email(response.get('email')) | |||||
| @classmethod | |||||
| def login_with_email(cls, email: str) -> str: | |||||
| account = Account.query.filter_by(email=email).first() | |||||
| if account is None: | |||||
| raise Exception('account not found, please contact system admin to invite you to join in a workspace') | |||||
| if account.status == AccountStatus.BANNED: | |||||
| raise Exception('account is banned, please contact system admin') | |||||
| tenants = TenantService.get_join_tenants(account) | |||||
| if len(tenants) == 0: | |||||
| raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") | |||||
| token = AccountService.get_account_jwt_token(account) | |||||
| return token |
| url: '/logout', | url: '/logout', | ||||
| params: {}, | params: {}, | ||||
| }) | }) | ||||
| if (localStorage?.getItem('console_token')) | |||||
| localStorage.removeItem('console_token') | |||||
| router.push('/signin') | router.push('/signin') | ||||
| } | } | ||||
| const Header = () => { | const Header = () => { | ||||
| const { locale, setLocaleOnClient } = useContext(I18n) | const { locale, setLocaleOnClient } = useContext(I18n) | ||||
| if (localStorage?.getItem('console_token')) | |||||
| localStorage.removeItem('console_token') | |||||
| return <div className='flex items-center justify-between p-6 w-full'> | return <div className='flex items-center justify-between p-6 w-full'> | ||||
| <LogoSite /> | <LogoSite /> | ||||
| <Select | <Select |
| 'use client' | |||||
| import cn from 'classnames' | |||||
| import { useRouter, useSearchParams } from 'next/navigation' | |||||
| import type { FC } from 'react' | |||||
| import { useEffect, useState } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import Toast from '@/app/components/base/toast' | |||||
| import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise' | |||||
| import Button from '@/app/components/base/button' | |||||
| type EnterpriseSSOFormProps = { | |||||
| protocol: string | |||||
| } | |||||
| const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({ | |||||
| protocol, | |||||
| }) => { | |||||
| const searchParams = useSearchParams() | |||||
| const consoleToken = searchParams.get('console_token') | |||||
| const message = searchParams.get('message') | |||||
| const router = useRouter() | |||||
| const { t } = useTranslation() | |||||
| const [isLoading, setIsLoading] = useState(false) | |||||
| useEffect(() => { | |||||
| if (consoleToken) { | |||||
| localStorage.setItem('console_token', consoleToken) | |||||
| router.replace('/apps') | |||||
| } | |||||
| if (message) { | |||||
| Toast.notify({ | |||||
| type: 'error', | |||||
| message, | |||||
| }) | |||||
| } | |||||
| }, []) | |||||
| const handleSSOLogin = () => { | |||||
| setIsLoading(true) | |||||
| if (protocol === 'saml') { | |||||
| getSAMLSSOUrl().then((res) => { | |||||
| router.push(res.url) | |||||
| }).finally(() => { | |||||
| setIsLoading(false) | |||||
| }) | |||||
| } | |||||
| else { | |||||
| getOIDCSSOUrl().then((res) => { | |||||
| document.cookie = `oidc-state=${res.state}` | |||||
| router.push(res.url) | |||||
| }).finally(() => { | |||||
| setIsLoading(false) | |||||
| }) | |||||
| } | |||||
| } | |||||
| return ( | |||||
| <div className={ | |||||
| cn( | |||||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||||
| 'px-6', | |||||
| 'md:px-[108px]', | |||||
| ) | |||||
| }> | |||||
| <div className='flex flex-col md:w-[400px]'> | |||||
| <div className="w-full mx-auto"> | |||||
| <h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2> | |||||
| </div> | |||||
| <div className="w-full mx-auto mt-10"> | |||||
| <Button | |||||
| tabIndex={0} | |||||
| type='primary' | |||||
| onClick={() => { handleSSOLogin() }} | |||||
| disabled={isLoading} | |||||
| className="w-full !fone-medium !text-sm" | |||||
| >{t('login.sso')} | |||||
| </Button> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| export default EnterpriseSSOForm |
| remember_me: true, | remember_me: true, | ||||
| }, | }, | ||||
| }) | }) | ||||
| localStorage.setItem('console_token', res.data) | |||||
| router.replace('/apps') | |||||
| if (res.result === 'success') { | |||||
| localStorage.setItem('console_token', res.data) | |||||
| router.replace('/apps') | |||||
| } | |||||
| else { | |||||
| Toast.notify({ | |||||
| type: 'error', | |||||
| message: res.data, | |||||
| }) | |||||
| } | |||||
| } | } | ||||
| finally { | finally { | ||||
| setIsLoading(false) | setIsLoading(false) |
| import React from 'react' | |||||
| 'use client' | |||||
| import React, { useEffect, useState } from 'react' | |||||
| import cn from 'classnames' | import cn from 'classnames' | ||||
| import Script from 'next/script' | import Script from 'next/script' | ||||
| import Loading from '../components/base/loading' | |||||
| import Forms from './forms' | import Forms from './forms' | ||||
| import Header from './_header' | import Header from './_header' | ||||
| import style from './page.module.css' | import style from './page.module.css' | ||||
| import EnterpriseSSOForm from './enterpriseSSOForm' | |||||
| import { IS_CE_EDITION } from '@/config' | import { IS_CE_EDITION } from '@/config' | ||||
| import { getEnterpriseFeatures } from '@/service/enterprise' | |||||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||||
| import { defaultEnterpriseFeatures } from '@/types/enterprise' | |||||
| const SignIn = () => { | const SignIn = () => { | ||||
| const [loading, setLoading] = useState<boolean>(true) | |||||
| const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures) | |||||
| useEffect(() => { | |||||
| getEnterpriseFeatures().then((res) => { | |||||
| setEnterpriseFeatures(res) | |||||
| }).finally(() => { | |||||
| setLoading(false) | |||||
| }) | |||||
| }, []) | |||||
| return ( | return ( | ||||
| <> | <> | ||||
| {!IS_CE_EDITION && ( | {!IS_CE_EDITION && ( | ||||
| ) | ) | ||||
| }> | }> | ||||
| <Header /> | <Header /> | ||||
| <Forms /> | |||||
| <div className='px-8 py-6 text-sm font-normal text-gray-500'> | |||||
| © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. | |||||
| </div> | |||||
| {loading && ( | |||||
| <div className={ | |||||
| cn( | |||||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||||
| 'px-6', | |||||
| 'md:px-[108px]', | |||||
| ) | |||||
| }> | |||||
| <Loading type='area' /> | |||||
| </div> | |||||
| )} | |||||
| {!loading && !enterpriseFeatures.sso_enforced_for_signin && ( | |||||
| <> | |||||
| <Forms /> | |||||
| <div className='px-8 py-6 text-sm font-normal text-gray-500'> | |||||
| © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. | |||||
| </div> | |||||
| </> | |||||
| )} | |||||
| {!loading && enterpriseFeatures.sso_enforced_for_signin && ( | |||||
| <EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} /> | |||||
| )} | |||||
| </div> | </div> | ||||
| </div> | </div> |
| namePlaceholder: 'Your username', | namePlaceholder: 'Your username', | ||||
| forget: 'Forgot your password?', | forget: 'Forgot your password?', | ||||
| signBtn: 'Sign in', | signBtn: 'Sign in', | ||||
| sso: 'Continue with SSO', | |||||
| installBtn: 'Set up', | installBtn: 'Set up', | ||||
| setAdminAccount: 'Setting up an admin account', | setAdminAccount: 'Setting up an admin account', | ||||
| setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.', | setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.', |
| import { get } from './base' | |||||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||||
| export const getEnterpriseFeatures = () => { | |||||
| return get<EnterpriseFeatures>('/enterprise-features') | |||||
| } | |||||
| export const getSAMLSSOUrl = () => { | |||||
| return get<{ url: string }>('/enterprise/sso/saml/login') | |||||
| } | |||||
| export const getOIDCSSOUrl = () => { | |||||
| return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') | |||||
| } |
| export type EnterpriseFeatures = { | |||||
| sso_enforced_for_signin: boolean | |||||
| sso_enforced_for_signin_protocol: string | |||||
| } | |||||
| export const defaultEnterpriseFeatures: EnterpriseFeatures = { | |||||
| sso_enforced_for_signin: false, | |||||
| sso_enforced_for_signin_protocol: '', | |||||
| } |