| @@ -37,9 +37,6 @@ from .billing import billing | |||
| # Import datasets controllers | |||
| from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing | |||
| # Import enterprise controllers | |||
| from .enterprise import enterprise_sso | |||
| # Import explore controllers | |||
| from .explore import ( | |||
| audio, | |||
| @@ -1,59 +0,0 @@ | |||
| from flask import current_app, redirect | |||
| from flask_restful import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from services.enterprise.enterprise_sso_service import EnterpriseSSOService | |||
| class EnterpriseSSOSamlLogin(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| return EnterpriseSSOService.get_sso_saml_login() | |||
| class EnterpriseSSOSamlAcs(Resource): | |||
| @setup_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('SAMLResponse', type=str, required=True, location='form') | |||
| args = parser.parse_args() | |||
| saml_response = args['SAMLResponse'] | |||
| try: | |||
| token = EnterpriseSSOService.post_sso_saml_acs(saml_response) | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||
| except Exception as e: | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||
| class EnterpriseSSOOidcLogin(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| return EnterpriseSSOService.get_sso_oidc_login() | |||
| class EnterpriseSSOOidcCallback(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('state', type=str, required=True, location='args') | |||
| parser.add_argument('code', type=str, required=True, location='args') | |||
| parser.add_argument('oidc-state', type=str, required=True, location='cookies') | |||
| args = parser.parse_args() | |||
| try: | |||
| token = EnterpriseSSOService.get_sso_oidc_callback(args) | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||
| except Exception as e: | |||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||
| api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') | |||
| api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') | |||
| api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') | |||
| api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') | |||
| @@ -1,7 +1,6 @@ | |||
| from flask_login import current_user | |||
| from flask_restful import Resource | |||
| from services.enterprise.enterprise_feature_service import EnterpriseFeatureService | |||
| from services.feature_service import FeatureService | |||
| from . import api | |||
| @@ -15,10 +14,10 @@ class FeatureApi(Resource): | |||
| return FeatureService.get_features(current_user.current_tenant_id).dict() | |||
| class EnterpriseFeatureApi(Resource): | |||
| class SystemFeatureApi(Resource): | |||
| def get(self): | |||
| return EnterpriseFeatureService.get_enterprise_features().dict() | |||
| return FeatureService.get_system_features().dict() | |||
| api.add_resource(FeatureApi, '/features') | |||
| api.add_resource(EnterpriseFeatureApi, '/enterprise-features') | |||
| api.add_resource(SystemFeatureApi, '/system-features') | |||
| @@ -6,4 +6,4 @@ bp = Blueprint('web', __name__, url_prefix='/api') | |||
| api = ExternalApi(bp) | |||
| from . import app, audio, completion, conversation, file, message, passport, saved_message, site, workflow | |||
| from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow | |||
| @@ -1,14 +1,10 @@ | |||
| import json | |||
| from flask import current_app | |||
| from flask_restful import fields, marshal_with | |||
| from controllers.web import api | |||
| from controllers.web.error import AppUnavailableError | |||
| from controllers.web.wraps import WebApiResource | |||
| from extensions.ext_database import db | |||
| from models.model import App, AppMode, AppModelConfig | |||
| from models.tools import ApiToolProvider | |||
| from models.model import App, AppMode | |||
| from services.app_service import AppService | |||
| @@ -115,3 +115,9 @@ class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| class WebSSOAuthRequiredError(BaseHTTPException): | |||
| error_code = 'web_sso_auth_required' | |||
| description = "Web SSO authentication required." | |||
| code = 401 | |||
| @@ -0,0 +1,12 @@ | |||
| from flask_restful import Resource | |||
| from controllers.web import api | |||
| from services.feature_service import FeatureService | |||
| class SystemFeatureApi(Resource): | |||
| def get(self): | |||
| return FeatureService.get_system_features().dict() | |||
| api.add_resource(SystemFeatureApi, '/system-features') | |||
| @@ -5,14 +5,21 @@ from flask_restful import Resource | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from controllers.web import api | |||
| from controllers.web.error import WebSSOAuthRequiredError | |||
| from extensions.ext_database import db | |||
| from libs.passport import PassportService | |||
| from models.model import App, EndUser, Site | |||
| from services.feature_service import FeatureService | |||
| class PassportResource(Resource): | |||
| """Base resource for passport.""" | |||
| def get(self): | |||
| system_features = FeatureService.get_system_features() | |||
| if system_features.sso_enforced_for_web: | |||
| raise WebSSOAuthRequiredError() | |||
| app_code = request.headers.get('X-App-Code') | |||
| if app_code is None: | |||
| raise Unauthorized('X-App-Code header is missing.') | |||
| @@ -28,7 +35,7 @@ class PassportResource(Resource): | |||
| app_model = db.session.query(App).filter(App.id == site.app_id).first() | |||
| if not app_model or app_model.status != 'normal' or not app_model.enable_site: | |||
| raise NotFound() | |||
| end_user = EndUser( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| @@ -36,6 +43,7 @@ class PassportResource(Resource): | |||
| is_anonymous=True, | |||
| session_id=generate_session_id(), | |||
| ) | |||
| db.session.add(end_user) | |||
| db.session.commit() | |||
| @@ -53,8 +61,10 @@ class PassportResource(Resource): | |||
| 'access_token': tk, | |||
| } | |||
| api.add_resource(PassportResource, '/passport') | |||
| def generate_session_id(): | |||
| """ | |||
| Generate a unique session ID. | |||
| @@ -2,11 +2,13 @@ from functools import wraps | |||
| from flask import request | |||
| from flask_restful import Resource | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from werkzeug.exceptions import BadRequest, NotFound, Unauthorized | |||
| from controllers.web.error import WebSSOAuthRequiredError | |||
| from extensions.ext_database import db | |||
| from libs.passport import PassportService | |||
| from models.model import App, EndUser, Site | |||
| from services.feature_service import FeatureService | |||
| def validate_jwt_token(view=None): | |||
| @@ -21,34 +23,60 @@ def validate_jwt_token(view=None): | |||
| return decorator(view) | |||
| return decorator | |||
| def decode_jwt_token(): | |||
| auth_header = request.headers.get('Authorization') | |||
| if auth_header is None: | |||
| raise Unauthorized('Authorization header is missing.') | |||
| if ' ' not in auth_header: | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| auth_scheme, tk = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| decoded = PassportService().verify(tk) | |||
| app_code = decoded.get('app_code') | |||
| app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() | |||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||
| if not app_model: | |||
| raise NotFound() | |||
| if not app_code or not site: | |||
| raise Unauthorized('Site URL is no longer valid.') | |||
| if app_model.enable_site is False: | |||
| raise Unauthorized('Site is disabled.') | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() | |||
| if not end_user: | |||
| raise NotFound() | |||
| return app_model, end_user | |||
| system_features = FeatureService.get_system_features() | |||
| try: | |||
| auth_header = request.headers.get('Authorization') | |||
| if auth_header is None: | |||
| raise Unauthorized('Authorization header is missing.') | |||
| if ' ' not in auth_header: | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| auth_scheme, tk = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| decoded = PassportService().verify(tk) | |||
| app_code = decoded.get('app_code') | |||
| app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() | |||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||
| if not app_model: | |||
| raise NotFound() | |||
| if not app_code or not site: | |||
| raise BadRequest('Site URL is no longer valid.') | |||
| if app_model.enable_site is False: | |||
| raise BadRequest('Site is disabled.') | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() | |||
| if not end_user: | |||
| raise NotFound() | |||
| _validate_web_sso_token(decoded, system_features) | |||
| return app_model, end_user | |||
| except Unauthorized as e: | |||
| if system_features.sso_enforced_for_web: | |||
| raise WebSSOAuthRequiredError() | |||
| raise Unauthorized(e.description) | |||
| def _validate_web_sso_token(decoded, system_features): | |||
| # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login | |||
| if system_features.sso_enforced_for_web: | |||
| source = decoded.get('token_source') | |||
| if not source or source != 'sso': | |||
| raise WebSSOAuthRequiredError() | |||
| # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login | |||
| if not system_features.sso_enforced_for_web: | |||
| source = decoded.get('token_source') | |||
| if source and source == 'sso': | |||
| raise Unauthorized('sso token expired.') | |||
| class WebApiResource(Resource): | |||
| method_decorators = [validate_jwt_token] | |||
| @@ -1,28 +0,0 @@ | |||
| from flask import current_app | |||
| from pydantic import BaseModel | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| class EnterpriseFeatureModel(BaseModel): | |||
| sso_enforced_for_signin: bool = False | |||
| sso_enforced_for_signin_protocol: str = '' | |||
| class EnterpriseFeatureService: | |||
| @classmethod | |||
| def get_enterprise_features(cls) -> EnterpriseFeatureModel: | |||
| features = EnterpriseFeatureModel() | |||
| if current_app.config['ENTERPRISE_ENABLED']: | |||
| cls._fulfill_params_from_enterprise(features) | |||
| return features | |||
| @classmethod | |||
| def _fulfill_params_from_enterprise(cls, features): | |||
| enterprise_info = EnterpriseService.get_info() | |||
| features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] | |||
| features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] | |||
| @@ -1,60 +0,0 @@ | |||
| import logging | |||
| from models.account import Account, AccountStatus | |||
| from services.account_service import AccountService, TenantService | |||
| from services.enterprise.base import EnterpriseRequest | |||
| logger = logging.getLogger(__name__) | |||
| class EnterpriseSSOService: | |||
| @classmethod | |||
| def get_sso_saml_login(cls) -> str: | |||
| return EnterpriseRequest.send_request('GET', '/sso/saml/login') | |||
| @classmethod | |||
| def post_sso_saml_acs(cls, saml_response: str) -> str: | |||
| response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) | |||
| if 'email' not in response or response['email'] is None: | |||
| logger.exception(response) | |||
| raise Exception('Saml response is invalid') | |||
| return cls.login_with_email(response.get('email')) | |||
| @classmethod | |||
| def get_sso_oidc_login(cls): | |||
| return EnterpriseRequest.send_request('GET', '/sso/oidc/login') | |||
| @classmethod | |||
| def get_sso_oidc_callback(cls, args: dict): | |||
| state_from_query = args['state'] | |||
| code_from_query = args['code'] | |||
| state_from_cookies = args['oidc-state'] | |||
| if state_from_cookies != state_from_query: | |||
| raise Exception('invalid state or code') | |||
| response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) | |||
| if 'email' not in response or response['email'] is None: | |||
| logger.exception(response) | |||
| raise Exception('OIDC response is invalid') | |||
| return cls.login_with_email(response.get('email')) | |||
| @classmethod | |||
| def login_with_email(cls, email: str) -> str: | |||
| account = Account.query.filter_by(email=email).first() | |||
| if account is None: | |||
| raise Exception('account not found, please contact system admin to invite you to join in a workspace') | |||
| if account.status == AccountStatus.BANNED: | |||
| raise Exception('account is banned, please contact system admin') | |||
| tenants = TenantService.get_join_tenants(account) | |||
| if len(tenants) == 0: | |||
| raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") | |||
| token = AccountService.get_account_jwt_token(account) | |||
| return token | |||
| @@ -2,6 +2,7 @@ from flask import current_app | |||
| from pydantic import BaseModel | |||
| from services.billing_service import BillingService | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| class SubscriptionModel(BaseModel): | |||
| @@ -30,6 +31,13 @@ class FeatureModel(BaseModel): | |||
| can_replace_logo: bool = False | |||
| class SystemFeatureModel(BaseModel): | |||
| sso_enforced_for_signin: bool = False | |||
| sso_enforced_for_signin_protocol: str = '' | |||
| sso_enforced_for_web: bool = False | |||
| sso_enforced_for_web_protocol: str = '' | |||
| class FeatureService: | |||
| @classmethod | |||
| @@ -43,6 +51,15 @@ class FeatureService: | |||
| return features | |||
| @classmethod | |||
| def get_system_features(cls) -> SystemFeatureModel: | |||
| system_features = SystemFeatureModel() | |||
| if current_app.config['ENTERPRISE_ENABLED']: | |||
| cls._fulfill_params_from_enterprise(system_features) | |||
| return system_features | |||
| @classmethod | |||
| def _fulfill_params_from_env(cls, features: FeatureModel): | |||
| features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] | |||
| @@ -73,3 +90,11 @@ class FeatureService: | |||
| features.docs_processing = billing_info['docs_processing'] | |||
| features.can_replace_logo = billing_info['can_replace_logo'] | |||
| @classmethod | |||
| def _fulfill_params_from_enterprise(cls, features): | |||
| enterprise_info = EnterpriseService.get_info() | |||
| features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] | |||
| features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] | |||
| features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] | |||
| features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] | |||
| @@ -1,5 +1,4 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| @@ -1,12 +1,87 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| import React, { useEffect } from 'react' | |||
| import cn from 'classnames' | |||
| import type { IMainProps } from '@/app/components/share/chat' | |||
| import Main from '@/app/components/share/chatbot' | |||
| import Loading from '@/app/components/base/loading' | |||
| import { fetchSystemFeatures } from '@/service/share' | |||
| import LogoSite from '@/app/components/base/logo/logo-site' | |||
| const Chatbot: FC<IMainProps> = () => { | |||
| const [isSSOEnforced, setIsSSOEnforced] = React.useState(true) | |||
| const [loading, setLoading] = React.useState(true) | |||
| useEffect(() => { | |||
| fetchSystemFeatures().then((res) => { | |||
| setIsSSOEnforced(res.sso_enforced_for_web) | |||
| setLoading(false) | |||
| }) | |||
| }, []) | |||
| return ( | |||
| <Main /> | |||
| <> | |||
| { | |||
| loading | |||
| ? ( | |||
| <div className="flex items-center justify-center h-full" > | |||
| <div className={ | |||
| cn( | |||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||
| 'px-6', | |||
| 'md:px-[108px]', | |||
| ) | |||
| }> | |||
| <Loading type='area' /> | |||
| </div> | |||
| </div > | |||
| ) | |||
| : ( | |||
| <> | |||
| {isSSOEnforced | |||
| ? ( | |||
| <div className={cn( | |||
| 'flex w-full min-h-screen', | |||
| 'sm:p-4 lg:p-8', | |||
| 'gap-x-20', | |||
| 'justify-center lg:justify-start', | |||
| )}> | |||
| <div className={ | |||
| cn( | |||
| 'flex w-full flex-col bg-white shadow rounded-2xl shrink-0', | |||
| 'space-between', | |||
| ) | |||
| }> | |||
| <div className='flex items-center justify-between p-6 w-full'> | |||
| <LogoSite /> | |||
| </div> | |||
| <div className={ | |||
| cn( | |||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||
| 'px-6', | |||
| 'md:px-[108px]', | |||
| ) | |||
| }> | |||
| <div className='flex flex-col md:w-[400px]'> | |||
| <div className="w-full mx-auto"> | |||
| <h2 className="text-[16px] font-bold text-gray-900"> | |||
| Warning: Chatbot is not available | |||
| </h2> | |||
| <p className="text-[16px] text-gray-600 mt-2"> | |||
| Because SSO is enforced. Please contact your administrator. | |||
| </p> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| ) | |||
| : <Main /> | |||
| } | |||
| </> | |||
| )} | |||
| </> | |||
| ) | |||
| } | |||
| @@ -0,0 +1,147 @@ | |||
| 'use client' | |||
| import cn from 'classnames' | |||
| import { useRouter, useSearchParams } from 'next/navigation' | |||
| import type { FC } from 'react' | |||
| import React, { useEffect, useState } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import Toast from '@/app/components/base/toast' | |||
| import Button from '@/app/components/base/button' | |||
| import { fetchSystemFeatures, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' | |||
| import LogoSite from '@/app/components/base/logo/logo-site' | |||
| import { setAccessToken } from '@/app/components/share/utils' | |||
| const WebSSOForm: FC = () => { | |||
| const searchParams = useSearchParams() | |||
| const redirectUrl = searchParams.get('redirect_url') | |||
| const tokenFromUrl = searchParams.get('web_sso_token') | |||
| const message = searchParams.get('message') | |||
| const router = useRouter() | |||
| const { t } = useTranslation() | |||
| const [isLoading, setIsLoading] = useState(false) | |||
| const [protocal, setProtocal] = useState('') | |||
| useEffect(() => { | |||
| const fetchFeaturesAndSetToken = async () => { | |||
| await fetchSystemFeatures().then((res) => { | |||
| setProtocal(res.sso_enforced_for_web_protocol) | |||
| }) | |||
| // Callback from SSO, process token and redirect | |||
| if (tokenFromUrl && redirectUrl) { | |||
| const appCode = redirectUrl.split('/').pop() | |||
| if (!appCode) { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: 'redirect url is invalid. App code is not found.', | |||
| }) | |||
| return | |||
| } | |||
| await setAccessToken(appCode, tokenFromUrl) | |||
| router.push(redirectUrl) | |||
| } | |||
| } | |||
| fetchFeaturesAndSetToken() | |||
| if (message) { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message, | |||
| }) | |||
| } | |||
| }, []) | |||
| const handleSSOLogin = () => { | |||
| setIsLoading(true) | |||
| if (!redirectUrl) { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: 'redirect url is not found.', | |||
| }) | |||
| setIsLoading(false) | |||
| return | |||
| } | |||
| const appCode = redirectUrl.split('/').pop() | |||
| if (!appCode) { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: 'redirect url is invalid. App code is not found.', | |||
| }) | |||
| return | |||
| } | |||
| if (protocal === 'saml') { | |||
| fetchWebSAMLSSOUrl(appCode, redirectUrl).then((res) => { | |||
| router.push(res.url) | |||
| }).finally(() => { | |||
| setIsLoading(false) | |||
| }) | |||
| } | |||
| else if (protocal === 'oidc') { | |||
| fetchWebOIDCSSOUrl(appCode, redirectUrl).then((res) => { | |||
| router.push(res.url) | |||
| }).finally(() => { | |||
| setIsLoading(false) | |||
| }) | |||
| } | |||
| else { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: 'sso protocal is not supported.', | |||
| }) | |||
| setIsLoading(false) | |||
| } | |||
| } | |||
| return ( | |||
| <div className={cn( | |||
| 'flex w-full min-h-screen', | |||
| 'sm:p-4 lg:p-8', | |||
| 'gap-x-20', | |||
| 'justify-center lg:justify-start', | |||
| )}> | |||
| <div className={ | |||
| cn( | |||
| 'flex w-full flex-col bg-white shadow rounded-2xl shrink-0', | |||
| 'space-between', | |||
| ) | |||
| }> | |||
| <div className='flex items-center justify-between p-6 w-full'> | |||
| <LogoSite /> | |||
| </div> | |||
| <div className={ | |||
| cn( | |||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||
| 'px-6', | |||
| 'md:px-[108px]', | |||
| ) | |||
| }> | |||
| <div className='flex flex-col md:w-[400px]'> | |||
| <div className="w-full mx-auto"> | |||
| <h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2> | |||
| </div> | |||
| <div className="w-full mx-auto mt-10"> | |||
| <Button | |||
| tabIndex={0} | |||
| type='primary' | |||
| onClick={() => { handleSSOLogin() }} | |||
| disabled={isLoading} | |||
| className="w-full !fone-medium !text-sm" | |||
| >{t('login.sso')} | |||
| </Button> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(WebSSOForm) | |||
| @@ -1,4 +1,6 @@ | |||
| import { CONVERSATION_ID_INFO } from '../base/chat/constants' | |||
| import { fetchAccessToken } from '@/service/share' | |||
| export const checkOrSetAccessToken = async () => { | |||
| const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] | |||
| const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) | |||
| @@ -15,3 +17,37 @@ export const checkOrSetAccessToken = async () => { | |||
| localStorage.setItem('token', JSON.stringify(accessTokenJson)) | |||
| } | |||
| } | |||
| export const setAccessToken = async (sharedToken: string, token: string) => { | |||
| const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) | |||
| let accessTokenJson = { [sharedToken]: '' } | |||
| try { | |||
| accessTokenJson = JSON.parse(accessToken) | |||
| } | |||
| catch (e) { | |||
| } | |||
| localStorage.removeItem(CONVERSATION_ID_INFO) | |||
| accessTokenJson[sharedToken] = token | |||
| localStorage.setItem('token', JSON.stringify(accessTokenJson)) | |||
| } | |||
| export const removeAccessToken = () => { | |||
| const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] | |||
| const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) | |||
| let accessTokenJson = { [sharedToken]: '' } | |||
| try { | |||
| accessTokenJson = JSON.parse(accessToken) | |||
| } | |||
| catch (e) { | |||
| } | |||
| localStorage.removeItem(CONVERSATION_ID_INFO) | |||
| delete accessTokenJson[sharedToken] | |||
| localStorage.setItem('token', JSON.stringify(accessTokenJson)) | |||
| } | |||
| @@ -6,19 +6,20 @@ import Loading from '../components/base/loading' | |||
| import Forms from './forms' | |||
| import Header from './_header' | |||
| import style from './page.module.css' | |||
| import EnterpriseSSOForm from './enterpriseSSOForm' | |||
| import UserSSOForm from './userSSOForm' | |||
| import { IS_CE_EDITION } from '@/config' | |||
| import { getEnterpriseFeatures } from '@/service/enterprise' | |||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||
| import { defaultEnterpriseFeatures } from '@/types/enterprise' | |||
| import type { SystemFeatures } from '@/types/feature' | |||
| import { defaultSystemFeatures } from '@/types/feature' | |||
| import { getSystemFeatures } from '@/service/common' | |||
| const SignIn = () => { | |||
| const [loading, setLoading] = useState<boolean>(true) | |||
| const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures) | |||
| const [systemFeatures, setSystemFeatures] = useState<SystemFeatures>(defaultSystemFeatures) | |||
| useEffect(() => { | |||
| getEnterpriseFeatures().then((res) => { | |||
| setEnterpriseFeatures(res) | |||
| getSystemFeatures().then((res) => { | |||
| setSystemFeatures(res) | |||
| }).finally(() => { | |||
| setLoading(false) | |||
| }) | |||
| @@ -70,7 +71,7 @@ gtag('config', 'AW-11217955271"'); | |||
| </div> | |||
| )} | |||
| {!loading && !enterpriseFeatures.sso_enforced_for_signin && ( | |||
| {!loading && !systemFeatures.sso_enforced_for_signin && ( | |||
| <> | |||
| <Forms /> | |||
| <div className='px-8 py-6 text-sm font-normal text-gray-500'> | |||
| @@ -79,8 +80,8 @@ gtag('config', 'AW-11217955271"'); | |||
| </> | |||
| )} | |||
| {!loading && enterpriseFeatures.sso_enforced_for_signin && ( | |||
| <EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} /> | |||
| {!loading && systemFeatures.sso_enforced_for_signin && ( | |||
| <UserSSOForm protocol={systemFeatures.sso_enforced_for_signin_protocol} /> | |||
| )} | |||
| </div> | |||
| @@ -5,14 +5,14 @@ import type { FC } from 'react' | |||
| import { useEffect, useState } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import Toast from '@/app/components/base/toast' | |||
| import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise' | |||
| import { getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso' | |||
| import Button from '@/app/components/base/button' | |||
| type EnterpriseSSOFormProps = { | |||
| type UserSSOFormProps = { | |||
| protocol: string | |||
| } | |||
| const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({ | |||
| const UserSSOForm: FC<UserSSOFormProps> = ({ | |||
| protocol, | |||
| }) => { | |||
| const searchParams = useSearchParams() | |||
| @@ -41,15 +41,15 @@ const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({ | |||
| const handleSSOLogin = () => { | |||
| setIsLoading(true) | |||
| if (protocol === 'saml') { | |||
| getSAMLSSOUrl().then((res) => { | |||
| getUserSAMLSSOUrl().then((res) => { | |||
| router.push(res.url) | |||
| }).finally(() => { | |||
| setIsLoading(false) | |||
| }) | |||
| } | |||
| else { | |||
| getOIDCSSOUrl().then((res) => { | |||
| document.cookie = `oidc-state=${res.state}` | |||
| getUserOIDCSSOUrl().then((res) => { | |||
| document.cookie = `user-oidc-state=${res.state}` | |||
| router.push(res.url) | |||
| }).finally(() => { | |||
| setIsLoading(false) | |||
| @@ -84,4 +84,4 @@ const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({ | |||
| ) | |||
| } | |||
| export default EnterpriseSSOForm | |||
| export default UserSSOForm | |||
| @@ -10,6 +10,7 @@ import type { | |||
| WorkflowFinishedResponse, | |||
| WorkflowStartedResponse, | |||
| } from '@/types/workflow' | |||
| import { removeAccessToken } from '@/app/components/share/utils' | |||
| const TIME_OUT = 100000 | |||
| const ContentType = { | |||
| @@ -97,6 +98,10 @@ function unicodeToChar(text: string) { | |||
| }) | |||
| } | |||
| function requiredWebSSOLogin() { | |||
| globalThis.location.href = `/webapp-signin?redirect_url=${globalThis.location.pathname}` | |||
| } | |||
| export function format(text: string) { | |||
| let res = text.trim() | |||
| if (res.startsWith('\n')) | |||
| @@ -308,6 +313,15 @@ const baseFetch = <T>( | |||
| return bodyJson.then((data: ResponseError) => { | |||
| if (!silent) | |||
| Toast.notify({ type: 'error', message: data.message }) | |||
| if (data.code === 'web_sso_auth_required') | |||
| requiredWebSSOLogin() | |||
| if (data.code === 'unauthorized') { | |||
| removeAccessToken() | |||
| globalThis.location.reload() | |||
| } | |||
| return Promise.reject(data) | |||
| }) | |||
| } | |||
| @@ -467,6 +481,16 @@ export const ssePost = ( | |||
| if (!/^(2|3)\d{2}$/.test(String(res.status))) { | |||
| res.json().then((data: any) => { | |||
| Toast.notify({ type: 'error', message: data.message || 'Server Error' }) | |||
| if (isPublicAPI) { | |||
| if (data.code === 'web_sso_auth_required') | |||
| requiredWebSSOLogin() | |||
| if (data.code === 'unauthorized') { | |||
| removeAccessToken() | |||
| globalThis.location.reload() | |||
| } | |||
| } | |||
| }) | |||
| onError?.('Server Error') | |||
| return | |||
| @@ -34,6 +34,7 @@ import type { | |||
| ModelProvider, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import type { RETRIEVE_METHOD } from '@/types/app' | |||
| import type { SystemFeatures } from '@/types/feature' | |||
| export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => { | |||
| return post(url, { body }) as Promise<CommonResponse & { data: string }> | |||
| @@ -271,3 +272,7 @@ type RetrievalMethodsRes = { | |||
| export const fetchSupportRetrievalMethods: Fetcher<RetrievalMethodsRes, string> = (url) => { | |||
| return get<RetrievalMethodsRes>(url) | |||
| } | |||
| export const getSystemFeatures = () => { | |||
| return get<SystemFeatures>('/system-features') | |||
| } | |||
| @@ -1,14 +0,0 @@ | |||
| import { get } from './base' | |||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||
| export const getEnterpriseFeatures = () => { | |||
| return get<EnterpriseFeatures>('/enterprise-features') | |||
| } | |||
| export const getSAMLSSOUrl = () => { | |||
| return get<{ url: string }>('/enterprise/sso/saml/login') | |||
| } | |||
| export const getOIDCSSOUrl = () => { | |||
| return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') | |||
| } | |||
| @@ -11,6 +11,7 @@ import type { | |||
| ConversationItem, | |||
| } from '@/models/share' | |||
| import type { ChatConfig } from '@/app/components/base/chat/types' | |||
| import type { SystemFeatures } from '@/types/feature' | |||
| function getAction(action: 'get' | 'post' | 'del' | 'patch', isInstalledApp: boolean) { | |||
| switch (action) { | |||
| @@ -135,6 +136,29 @@ export const fetchAppParams = async (isInstalledApp: boolean, installedAppId = ' | |||
| return (getAction('get', isInstalledApp))(getUrl('parameters', isInstalledApp, installedAppId)) as Promise<ChatConfig> | |||
| } | |||
| export const fetchSystemFeatures = async () => { | |||
| return (getAction('get', false))(getUrl('system-features', false, '')) as Promise<SystemFeatures> | |||
| } | |||
| export const fetchWebSAMLSSOUrl = async (appCode: string, redirectUrl: string) => { | |||
| return (getAction('get', false))(getUrl('/enterprise/sso/saml/login', false, ''), { | |||
| params: { | |||
| app_code: appCode, | |||
| redirect_url: redirectUrl, | |||
| }, | |||
| }) as Promise<{ url: string }> | |||
| } | |||
| export const fetchWebOIDCSSOUrl = async (appCode: string, redirectUrl: string) => { | |||
| return (getAction('get', false))(getUrl('/enterprise/sso/oidc/login', false, ''), { | |||
| params: { | |||
| app_code: appCode, | |||
| redirect_url: redirectUrl, | |||
| }, | |||
| }) as Promise<{ url: string }> | |||
| } | |||
| export const fetchAppMeta = async (isInstalledApp: boolean, installedAppId = '') => { | |||
| return (getAction('get', isInstalledApp))(getUrl('meta', isInstalledApp, installedAppId)) as Promise<AppMeta> | |||
| } | |||
| @@ -0,0 +1,9 @@ | |||
| import { get } from './base' | |||
| export const getUserSAMLSSOUrl = () => { | |||
| return get<{ url: string }>('/enterprise/sso/saml/login') | |||
| } | |||
| export const getUserOIDCSSOUrl = () => { | |||
| return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') | |||
| } | |||
| @@ -1,9 +0,0 @@ | |||
| export type EnterpriseFeatures = { | |||
| sso_enforced_for_signin: boolean | |||
| sso_enforced_for_signin_protocol: string | |||
| } | |||
| export const defaultEnterpriseFeatures: EnterpriseFeatures = { | |||
| sso_enforced_for_signin: false, | |||
| sso_enforced_for_signin_protocol: '', | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| export type SystemFeatures = { | |||
| sso_enforced_for_signin: boolean | |||
| sso_enforced_for_signin_protocol: string | |||
| sso_enforced_for_web: boolean | |||
| sso_enforced_for_web_protocol: string | |||
| } | |||
| export const defaultSystemFeatures: SystemFeatures = { | |||
| sso_enforced_for_signin: false, | |||
| sso_enforced_for_signin_protocol: '', | |||
| sso_enforced_for_web: false, | |||
| sso_enforced_for_web_protocol: '', | |||
| } | |||