| # Import datasets controllers | # Import datasets controllers | ||||
| from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing | from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing | ||||
| # Import enterprise controllers | |||||
| from .enterprise import enterprise_sso | |||||
| # Import explore controllers | # Import explore controllers | ||||
| from .explore import ( | from .explore import ( | ||||
| audio, | audio, |
| from flask import current_app, redirect | |||||
| from flask_restful import Resource, reqparse | |||||
| from controllers.console import api | |||||
| from controllers.console.setup import setup_required | |||||
| from services.enterprise.enterprise_sso_service import EnterpriseSSOService | |||||
| class EnterpriseSSOSamlLogin(Resource): | |||||
| @setup_required | |||||
| def get(self): | |||||
| return EnterpriseSSOService.get_sso_saml_login() | |||||
| class EnterpriseSSOSamlAcs(Resource): | |||||
| @setup_required | |||||
| def post(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('SAMLResponse', type=str, required=True, location='form') | |||||
| args = parser.parse_args() | |||||
| saml_response = args['SAMLResponse'] | |||||
| try: | |||||
| token = EnterpriseSSOService.post_sso_saml_acs(saml_response) | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||||
| except Exception as e: | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||||
| class EnterpriseSSOOidcLogin(Resource): | |||||
| @setup_required | |||||
| def get(self): | |||||
| return EnterpriseSSOService.get_sso_oidc_login() | |||||
| class EnterpriseSSOOidcCallback(Resource): | |||||
| @setup_required | |||||
| def get(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('state', type=str, required=True, location='args') | |||||
| parser.add_argument('code', type=str, required=True, location='args') | |||||
| parser.add_argument('oidc-state', type=str, required=True, location='cookies') | |||||
| args = parser.parse_args() | |||||
| try: | |||||
| token = EnterpriseSSOService.get_sso_oidc_callback(args) | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') | |||||
| except Exception as e: | |||||
| return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') | |||||
| api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') | |||||
| api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') | |||||
| api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') | |||||
| api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') |
| from flask_login import current_user | from flask_login import current_user | ||||
| from flask_restful import Resource | from flask_restful import Resource | ||||
| from services.enterprise.enterprise_feature_service import EnterpriseFeatureService | |||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| from . import api | from . import api | ||||
| return FeatureService.get_features(current_user.current_tenant_id).dict() | return FeatureService.get_features(current_user.current_tenant_id).dict() | ||||
| class EnterpriseFeatureApi(Resource): | |||||
| class SystemFeatureApi(Resource): | |||||
| def get(self): | def get(self): | ||||
| return EnterpriseFeatureService.get_enterprise_features().dict() | |||||
| return FeatureService.get_system_features().dict() | |||||
| api.add_resource(FeatureApi, '/features') | api.add_resource(FeatureApi, '/features') | ||||
| api.add_resource(EnterpriseFeatureApi, '/enterprise-features') | |||||
| api.add_resource(SystemFeatureApi, '/system-features') |
| api = ExternalApi(bp) | api = ExternalApi(bp) | ||||
| from . import app, audio, completion, conversation, file, message, passport, saved_message, site, workflow | |||||
| from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow |
| import json | |||||
| from flask import current_app | from flask import current_app | ||||
| from flask_restful import fields, marshal_with | from flask_restful import fields, marshal_with | ||||
| from controllers.web import api | from controllers.web import api | ||||
| from controllers.web.error import AppUnavailableError | from controllers.web.error import AppUnavailableError | ||||
| from controllers.web.wraps import WebApiResource | from controllers.web.wraps import WebApiResource | ||||
| from extensions.ext_database import db | |||||
| from models.model import App, AppMode, AppModelConfig | |||||
| from models.tools import ApiToolProvider | |||||
| from models.model import App, AppMode | |||||
| from services.app_service import AppService | from services.app_service import AppService | ||||
| error_code = 'unsupported_file_type' | error_code = 'unsupported_file_type' | ||||
| description = "File type not allowed." | description = "File type not allowed." | ||||
| code = 415 | code = 415 | ||||
| class WebSSOAuthRequiredError(BaseHTTPException): | |||||
| error_code = 'web_sso_auth_required' | |||||
| description = "Web SSO authentication required." | |||||
| code = 401 |
| from flask_restful import Resource | |||||
| from controllers.web import api | |||||
| from services.feature_service import FeatureService | |||||
| class SystemFeatureApi(Resource): | |||||
| def get(self): | |||||
| return FeatureService.get_system_features().dict() | |||||
| api.add_resource(SystemFeatureApi, '/system-features') |
| from werkzeug.exceptions import NotFound, Unauthorized | from werkzeug.exceptions import NotFound, Unauthorized | ||||
| from controllers.web import api | from controllers.web import api | ||||
| from controllers.web.error import WebSSOAuthRequiredError | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.passport import PassportService | from libs.passport import PassportService | ||||
| from models.model import App, EndUser, Site | from models.model import App, EndUser, Site | ||||
| from services.feature_service import FeatureService | |||||
| class PassportResource(Resource): | class PassportResource(Resource): | ||||
| """Base resource for passport.""" | """Base resource for passport.""" | ||||
| def get(self): | def get(self): | ||||
| system_features = FeatureService.get_system_features() | |||||
| if system_features.sso_enforced_for_web: | |||||
| raise WebSSOAuthRequiredError() | |||||
| app_code = request.headers.get('X-App-Code') | app_code = request.headers.get('X-App-Code') | ||||
| if app_code is None: | if app_code is None: | ||||
| raise Unauthorized('X-App-Code header is missing.') | raise Unauthorized('X-App-Code header is missing.') | ||||
| app_model = db.session.query(App).filter(App.id == site.app_id).first() | app_model = db.session.query(App).filter(App.id == site.app_id).first() | ||||
| if not app_model or app_model.status != 'normal' or not app_model.enable_site: | if not app_model or app_model.status != 'normal' or not app_model.enable_site: | ||||
| raise NotFound() | raise NotFound() | ||||
| end_user = EndUser( | end_user = EndUser( | ||||
| tenant_id=app_model.tenant_id, | tenant_id=app_model.tenant_id, | ||||
| app_id=app_model.id, | app_id=app_model.id, | ||||
| is_anonymous=True, | is_anonymous=True, | ||||
| session_id=generate_session_id(), | session_id=generate_session_id(), | ||||
| ) | ) | ||||
| db.session.add(end_user) | db.session.add(end_user) | ||||
| db.session.commit() | db.session.commit() | ||||
| 'access_token': tk, | 'access_token': tk, | ||||
| } | } | ||||
| api.add_resource(PassportResource, '/passport') | api.add_resource(PassportResource, '/passport') | ||||
| def generate_session_id(): | def generate_session_id(): | ||||
| """ | """ | ||||
| Generate a unique session ID. | Generate a unique session ID. |
| from flask import request | from flask import request | ||||
| from flask_restful import Resource | from flask_restful import Resource | ||||
| from werkzeug.exceptions import NotFound, Unauthorized | |||||
| from werkzeug.exceptions import BadRequest, NotFound, Unauthorized | |||||
| from controllers.web.error import WebSSOAuthRequiredError | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.passport import PassportService | from libs.passport import PassportService | ||||
| from models.model import App, EndUser, Site | from models.model import App, EndUser, Site | ||||
| from services.feature_service import FeatureService | |||||
| def validate_jwt_token(view=None): | def validate_jwt_token(view=None): | ||||
| return decorator(view) | return decorator(view) | ||||
| return decorator | return decorator | ||||
| def decode_jwt_token(): | def decode_jwt_token(): | ||||
| auth_header = request.headers.get('Authorization') | |||||
| if auth_header is None: | |||||
| raise Unauthorized('Authorization header is missing.') | |||||
| if ' ' not in auth_header: | |||||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||||
| auth_scheme, tk = auth_header.split(None, 1) | |||||
| auth_scheme = auth_scheme.lower() | |||||
| if auth_scheme != 'bearer': | |||||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||||
| decoded = PassportService().verify(tk) | |||||
| app_code = decoded.get('app_code') | |||||
| app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() | |||||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||||
| if not app_model: | |||||
| raise NotFound() | |||||
| if not app_code or not site: | |||||
| raise Unauthorized('Site URL is no longer valid.') | |||||
| if app_model.enable_site is False: | |||||
| raise Unauthorized('Site is disabled.') | |||||
| end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() | |||||
| if not end_user: | |||||
| raise NotFound() | |||||
| return app_model, end_user | |||||
| system_features = FeatureService.get_system_features() | |||||
| try: | |||||
| auth_header = request.headers.get('Authorization') | |||||
| if auth_header is None: | |||||
| raise Unauthorized('Authorization header is missing.') | |||||
| if ' ' not in auth_header: | |||||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||||
| auth_scheme, tk = auth_header.split(None, 1) | |||||
| auth_scheme = auth_scheme.lower() | |||||
| if auth_scheme != 'bearer': | |||||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||||
| decoded = PassportService().verify(tk) | |||||
| app_code = decoded.get('app_code') | |||||
| app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() | |||||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||||
| if not app_model: | |||||
| raise NotFound() | |||||
| if not app_code or not site: | |||||
| raise BadRequest('Site URL is no longer valid.') | |||||
| if app_model.enable_site is False: | |||||
| raise BadRequest('Site is disabled.') | |||||
| end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() | |||||
| if not end_user: | |||||
| raise NotFound() | |||||
| _validate_web_sso_token(decoded, system_features) | |||||
| return app_model, end_user | |||||
| except Unauthorized as e: | |||||
| if system_features.sso_enforced_for_web: | |||||
| raise WebSSOAuthRequiredError() | |||||
| raise Unauthorized(e.description) | |||||
| def _validate_web_sso_token(decoded, system_features): | |||||
| # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login | |||||
| if system_features.sso_enforced_for_web: | |||||
| source = decoded.get('token_source') | |||||
| if not source or source != 'sso': | |||||
| raise WebSSOAuthRequiredError() | |||||
| # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login | |||||
| if not system_features.sso_enforced_for_web: | |||||
| source = decoded.get('token_source') | |||||
| if source and source == 'sso': | |||||
| raise Unauthorized('sso token expired.') | |||||
| class WebApiResource(Resource): | class WebApiResource(Resource): | ||||
| method_decorators = [validate_jwt_token] | method_decorators = [validate_jwt_token] |
| from flask import current_app | |||||
| from pydantic import BaseModel | |||||
| from services.enterprise.enterprise_service import EnterpriseService | |||||
| class EnterpriseFeatureModel(BaseModel): | |||||
| sso_enforced_for_signin: bool = False | |||||
| sso_enforced_for_signin_protocol: str = '' | |||||
| class EnterpriseFeatureService: | |||||
| @classmethod | |||||
| def get_enterprise_features(cls) -> EnterpriseFeatureModel: | |||||
| features = EnterpriseFeatureModel() | |||||
| if current_app.config['ENTERPRISE_ENABLED']: | |||||
| cls._fulfill_params_from_enterprise(features) | |||||
| return features | |||||
| @classmethod | |||||
| def _fulfill_params_from_enterprise(cls, features): | |||||
| enterprise_info = EnterpriseService.get_info() | |||||
| features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] | |||||
| features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] |
| import logging | |||||
| from models.account import Account, AccountStatus | |||||
| from services.account_service import AccountService, TenantService | |||||
| from services.enterprise.base import EnterpriseRequest | |||||
| logger = logging.getLogger(__name__) | |||||
| class EnterpriseSSOService: | |||||
| @classmethod | |||||
| def get_sso_saml_login(cls) -> str: | |||||
| return EnterpriseRequest.send_request('GET', '/sso/saml/login') | |||||
| @classmethod | |||||
| def post_sso_saml_acs(cls, saml_response: str) -> str: | |||||
| response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) | |||||
| if 'email' not in response or response['email'] is None: | |||||
| logger.exception(response) | |||||
| raise Exception('Saml response is invalid') | |||||
| return cls.login_with_email(response.get('email')) | |||||
| @classmethod | |||||
| def get_sso_oidc_login(cls): | |||||
| return EnterpriseRequest.send_request('GET', '/sso/oidc/login') | |||||
| @classmethod | |||||
| def get_sso_oidc_callback(cls, args: dict): | |||||
| state_from_query = args['state'] | |||||
| code_from_query = args['code'] | |||||
| state_from_cookies = args['oidc-state'] | |||||
| if state_from_cookies != state_from_query: | |||||
| raise Exception('invalid state or code') | |||||
| response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) | |||||
| if 'email' not in response or response['email'] is None: | |||||
| logger.exception(response) | |||||
| raise Exception('OIDC response is invalid') | |||||
| return cls.login_with_email(response.get('email')) | |||||
| @classmethod | |||||
| def login_with_email(cls, email: str) -> str: | |||||
| account = Account.query.filter_by(email=email).first() | |||||
| if account is None: | |||||
| raise Exception('account not found, please contact system admin to invite you to join in a workspace') | |||||
| if account.status == AccountStatus.BANNED: | |||||
| raise Exception('account is banned, please contact system admin') | |||||
| tenants = TenantService.get_join_tenants(account) | |||||
| if len(tenants) == 0: | |||||
| raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") | |||||
| token = AccountService.get_account_jwt_token(account) | |||||
| return token |
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from services.billing_service import BillingService | from services.billing_service import BillingService | ||||
| from services.enterprise.enterprise_service import EnterpriseService | |||||
| class SubscriptionModel(BaseModel): | class SubscriptionModel(BaseModel): | ||||
| can_replace_logo: bool = False | can_replace_logo: bool = False | ||||
| class SystemFeatureModel(BaseModel): | |||||
| sso_enforced_for_signin: bool = False | |||||
| sso_enforced_for_signin_protocol: str = '' | |||||
| sso_enforced_for_web: bool = False | |||||
| sso_enforced_for_web_protocol: str = '' | |||||
| class FeatureService: | class FeatureService: | ||||
| @classmethod | @classmethod | ||||
| return features | return features | ||||
| @classmethod | |||||
| def get_system_features(cls) -> SystemFeatureModel: | |||||
| system_features = SystemFeatureModel() | |||||
| if current_app.config['ENTERPRISE_ENABLED']: | |||||
| cls._fulfill_params_from_enterprise(system_features) | |||||
| return system_features | |||||
| @classmethod | @classmethod | ||||
| def _fulfill_params_from_env(cls, features: FeatureModel): | def _fulfill_params_from_env(cls, features: FeatureModel): | ||||
| features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] | features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO'] | ||||
| features.docs_processing = billing_info['docs_processing'] | features.docs_processing = billing_info['docs_processing'] | ||||
| features.can_replace_logo = billing_info['can_replace_logo'] | features.can_replace_logo = billing_info['can_replace_logo'] | ||||
| @classmethod | |||||
| def _fulfill_params_from_enterprise(cls, features): | |||||
| enterprise_info = EnterpriseService.get_info() | |||||
| features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] | |||||
| features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] | |||||
| features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] | |||||
| features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] |
| 'use client' | 'use client' | ||||
| import type { FC } from 'react' | import type { FC } from 'react' | ||||
| import React from 'react' | import React from 'react' | ||||
| 'use client' | |||||
| import type { FC } from 'react' | import type { FC } from 'react' | ||||
| import React from 'react' | |||||
| import React, { useEffect } from 'react' | |||||
| import cn from 'classnames' | |||||
| import type { IMainProps } from '@/app/components/share/chat' | import type { IMainProps } from '@/app/components/share/chat' | ||||
| import Main from '@/app/components/share/chatbot' | import Main from '@/app/components/share/chatbot' | ||||
| import Loading from '@/app/components/base/loading' | |||||
| import { fetchSystemFeatures } from '@/service/share' | |||||
| import LogoSite from '@/app/components/base/logo/logo-site' | |||||
| const Chatbot: FC<IMainProps> = () => { | const Chatbot: FC<IMainProps> = () => { | ||||
| const [isSSOEnforced, setIsSSOEnforced] = React.useState(true) | |||||
| const [loading, setLoading] = React.useState(true) | |||||
| useEffect(() => { | |||||
| fetchSystemFeatures().then((res) => { | |||||
| setIsSSOEnforced(res.sso_enforced_for_web) | |||||
| setLoading(false) | |||||
| }) | |||||
| }, []) | |||||
| return ( | return ( | ||||
| <Main /> | |||||
| <> | |||||
| { | |||||
| loading | |||||
| ? ( | |||||
| <div className="flex items-center justify-center h-full" > | |||||
| <div className={ | |||||
| cn( | |||||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||||
| 'px-6', | |||||
| 'md:px-[108px]', | |||||
| ) | |||||
| }> | |||||
| <Loading type='area' /> | |||||
| </div> | |||||
| </div > | |||||
| ) | |||||
| : ( | |||||
| <> | |||||
| {isSSOEnforced | |||||
| ? ( | |||||
| <div className={cn( | |||||
| 'flex w-full min-h-screen', | |||||
| 'sm:p-4 lg:p-8', | |||||
| 'gap-x-20', | |||||
| 'justify-center lg:justify-start', | |||||
| )}> | |||||
| <div className={ | |||||
| cn( | |||||
| 'flex w-full flex-col bg-white shadow rounded-2xl shrink-0', | |||||
| 'space-between', | |||||
| ) | |||||
| }> | |||||
| <div className='flex items-center justify-between p-6 w-full'> | |||||
| <LogoSite /> | |||||
| </div> | |||||
| <div className={ | |||||
| cn( | |||||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||||
| 'px-6', | |||||
| 'md:px-[108px]', | |||||
| ) | |||||
| }> | |||||
| <div className='flex flex-col md:w-[400px]'> | |||||
| <div className="w-full mx-auto"> | |||||
| <h2 className="text-[16px] font-bold text-gray-900"> | |||||
| Warning: Chatbot is not available | |||||
| </h2> | |||||
| <p className="text-[16px] text-gray-600 mt-2"> | |||||
| Because SSO is enforced. Please contact your administrator. | |||||
| </p> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| ) | |||||
| : <Main /> | |||||
| } | |||||
| </> | |||||
| )} | |||||
| </> | |||||
| ) | ) | ||||
| } | } | ||||
| 'use client' | |||||
| import cn from 'classnames' | |||||
| import { useRouter, useSearchParams } from 'next/navigation' | |||||
| import type { FC } from 'react' | |||||
| import React, { useEffect, useState } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import Toast from '@/app/components/base/toast' | |||||
| import Button from '@/app/components/base/button' | |||||
| import { fetchSystemFeatures, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' | |||||
| import LogoSite from '@/app/components/base/logo/logo-site' | |||||
| import { setAccessToken } from '@/app/components/share/utils' | |||||
| const WebSSOForm: FC = () => { | |||||
| const searchParams = useSearchParams() | |||||
| const redirectUrl = searchParams.get('redirect_url') | |||||
| const tokenFromUrl = searchParams.get('web_sso_token') | |||||
| const message = searchParams.get('message') | |||||
| const router = useRouter() | |||||
| const { t } = useTranslation() | |||||
| const [isLoading, setIsLoading] = useState(false) | |||||
| const [protocal, setProtocal] = useState('') | |||||
| useEffect(() => { | |||||
| const fetchFeaturesAndSetToken = async () => { | |||||
| await fetchSystemFeatures().then((res) => { | |||||
| setProtocal(res.sso_enforced_for_web_protocol) | |||||
| }) | |||||
| // Callback from SSO, process token and redirect | |||||
| if (tokenFromUrl && redirectUrl) { | |||||
| const appCode = redirectUrl.split('/').pop() | |||||
| if (!appCode) { | |||||
| Toast.notify({ | |||||
| type: 'error', | |||||
| message: 'redirect url is invalid. App code is not found.', | |||||
| }) | |||||
| return | |||||
| } | |||||
| await setAccessToken(appCode, tokenFromUrl) | |||||
| router.push(redirectUrl) | |||||
| } | |||||
| } | |||||
| fetchFeaturesAndSetToken() | |||||
| if (message) { | |||||
| Toast.notify({ | |||||
| type: 'error', | |||||
| message, | |||||
| }) | |||||
| } | |||||
| }, []) | |||||
| const handleSSOLogin = () => { | |||||
| setIsLoading(true) | |||||
| if (!redirectUrl) { | |||||
| Toast.notify({ | |||||
| type: 'error', | |||||
| message: 'redirect url is not found.', | |||||
| }) | |||||
| setIsLoading(false) | |||||
| return | |||||
| } | |||||
| const appCode = redirectUrl.split('/').pop() | |||||
| if (!appCode) { | |||||
| Toast.notify({ | |||||
| type: 'error', | |||||
| message: 'redirect url is invalid. App code is not found.', | |||||
| }) | |||||
| return | |||||
| } | |||||
| if (protocal === 'saml') { | |||||
| fetchWebSAMLSSOUrl(appCode, redirectUrl).then((res) => { | |||||
| router.push(res.url) | |||||
| }).finally(() => { | |||||
| setIsLoading(false) | |||||
| }) | |||||
| } | |||||
| else if (protocal === 'oidc') { | |||||
| fetchWebOIDCSSOUrl(appCode, redirectUrl).then((res) => { | |||||
| router.push(res.url) | |||||
| }).finally(() => { | |||||
| setIsLoading(false) | |||||
| }) | |||||
| } | |||||
| else { | |||||
| Toast.notify({ | |||||
| type: 'error', | |||||
| message: 'sso protocal is not supported.', | |||||
| }) | |||||
| setIsLoading(false) | |||||
| } | |||||
| } | |||||
| return ( | |||||
| <div className={cn( | |||||
| 'flex w-full min-h-screen', | |||||
| 'sm:p-4 lg:p-8', | |||||
| 'gap-x-20', | |||||
| 'justify-center lg:justify-start', | |||||
| )}> | |||||
| <div className={ | |||||
| cn( | |||||
| 'flex w-full flex-col bg-white shadow rounded-2xl shrink-0', | |||||
| 'space-between', | |||||
| ) | |||||
| }> | |||||
| <div className='flex items-center justify-between p-6 w-full'> | |||||
| <LogoSite /> | |||||
| </div> | |||||
| <div className={ | |||||
| cn( | |||||
| 'flex flex-col items-center w-full grow items-center justify-center', | |||||
| 'px-6', | |||||
| 'md:px-[108px]', | |||||
| ) | |||||
| }> | |||||
| <div className='flex flex-col md:w-[400px]'> | |||||
| <div className="w-full mx-auto"> | |||||
| <h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2> | |||||
| </div> | |||||
| <div className="w-full mx-auto mt-10"> | |||||
| <Button | |||||
| tabIndex={0} | |||||
| type='primary' | |||||
| onClick={() => { handleSSOLogin() }} | |||||
| disabled={isLoading} | |||||
| className="w-full !fone-medium !text-sm" | |||||
| >{t('login.sso')} | |||||
| </Button> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| export default React.memo(WebSSOForm) |
| import { CONVERSATION_ID_INFO } from '../base/chat/constants' | |||||
| import { fetchAccessToken } from '@/service/share' | import { fetchAccessToken } from '@/service/share' | ||||
| export const checkOrSetAccessToken = async () => { | export const checkOrSetAccessToken = async () => { | ||||
| const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] | const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] | ||||
| const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) | const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) | ||||
| localStorage.setItem('token', JSON.stringify(accessTokenJson)) | localStorage.setItem('token', JSON.stringify(accessTokenJson)) | ||||
| } | } | ||||
| } | } | ||||
| export const setAccessToken = async (sharedToken: string, token: string) => { | |||||
| const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) | |||||
| let accessTokenJson = { [sharedToken]: '' } | |||||
| try { | |||||
| accessTokenJson = JSON.parse(accessToken) | |||||
| } | |||||
| catch (e) { | |||||
| } | |||||
| localStorage.removeItem(CONVERSATION_ID_INFO) | |||||
| accessTokenJson[sharedToken] = token | |||||
| localStorage.setItem('token', JSON.stringify(accessTokenJson)) | |||||
| } | |||||
| export const removeAccessToken = () => { | |||||
| const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] | |||||
| const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) | |||||
| let accessTokenJson = { [sharedToken]: '' } | |||||
| try { | |||||
| accessTokenJson = JSON.parse(accessToken) | |||||
| } | |||||
| catch (e) { | |||||
| } | |||||
| localStorage.removeItem(CONVERSATION_ID_INFO) | |||||
| delete accessTokenJson[sharedToken] | |||||
| localStorage.setItem('token', JSON.stringify(accessTokenJson)) | |||||
| } |
| import Forms from './forms' | import Forms from './forms' | ||||
| import Header from './_header' | import Header from './_header' | ||||
| import style from './page.module.css' | import style from './page.module.css' | ||||
| import EnterpriseSSOForm from './enterpriseSSOForm' | |||||
| import UserSSOForm from './userSSOForm' | |||||
| import { IS_CE_EDITION } from '@/config' | import { IS_CE_EDITION } from '@/config' | ||||
| import { getEnterpriseFeatures } from '@/service/enterprise' | |||||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||||
| import { defaultEnterpriseFeatures } from '@/types/enterprise' | |||||
| import type { SystemFeatures } from '@/types/feature' | |||||
| import { defaultSystemFeatures } from '@/types/feature' | |||||
| import { getSystemFeatures } from '@/service/common' | |||||
| const SignIn = () => { | const SignIn = () => { | ||||
| const [loading, setLoading] = useState<boolean>(true) | const [loading, setLoading] = useState<boolean>(true) | ||||
| const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures) | |||||
| const [systemFeatures, setSystemFeatures] = useState<SystemFeatures>(defaultSystemFeatures) | |||||
| useEffect(() => { | useEffect(() => { | ||||
| getEnterpriseFeatures().then((res) => { | |||||
| setEnterpriseFeatures(res) | |||||
| getSystemFeatures().then((res) => { | |||||
| setSystemFeatures(res) | |||||
| }).finally(() => { | }).finally(() => { | ||||
| setLoading(false) | setLoading(false) | ||||
| }) | }) | ||||
| </div> | </div> | ||||
| )} | )} | ||||
| {!loading && !enterpriseFeatures.sso_enforced_for_signin && ( | |||||
| {!loading && !systemFeatures.sso_enforced_for_signin && ( | |||||
| <> | <> | ||||
| <Forms /> | <Forms /> | ||||
| <div className='px-8 py-6 text-sm font-normal text-gray-500'> | <div className='px-8 py-6 text-sm font-normal text-gray-500'> | ||||
| </> | </> | ||||
| )} | )} | ||||
| {!loading && enterpriseFeatures.sso_enforced_for_signin && ( | |||||
| <EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} /> | |||||
| {!loading && systemFeatures.sso_enforced_for_signin && ( | |||||
| <UserSSOForm protocol={systemFeatures.sso_enforced_for_signin_protocol} /> | |||||
| )} | )} | ||||
| </div> | </div> | ||||
| import { useEffect, useState } from 'react' | import { useEffect, useState } from 'react' | ||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import Toast from '@/app/components/base/toast' | import Toast from '@/app/components/base/toast' | ||||
| import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise' | |||||
| import { getUserOIDCSSOUrl, getUserSAMLSSOUrl } from '@/service/sso' | |||||
| import Button from '@/app/components/base/button' | import Button from '@/app/components/base/button' | ||||
| type EnterpriseSSOFormProps = { | |||||
| type UserSSOFormProps = { | |||||
| protocol: string | protocol: string | ||||
| } | } | ||||
| const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({ | |||||
| const UserSSOForm: FC<UserSSOFormProps> = ({ | |||||
| protocol, | protocol, | ||||
| }) => { | }) => { | ||||
| const searchParams = useSearchParams() | const searchParams = useSearchParams() | ||||
| const handleSSOLogin = () => { | const handleSSOLogin = () => { | ||||
| setIsLoading(true) | setIsLoading(true) | ||||
| if (protocol === 'saml') { | if (protocol === 'saml') { | ||||
| getSAMLSSOUrl().then((res) => { | |||||
| getUserSAMLSSOUrl().then((res) => { | |||||
| router.push(res.url) | router.push(res.url) | ||||
| }).finally(() => { | }).finally(() => { | ||||
| setIsLoading(false) | setIsLoading(false) | ||||
| }) | }) | ||||
| } | } | ||||
| else { | else { | ||||
| getOIDCSSOUrl().then((res) => { | |||||
| document.cookie = `oidc-state=${res.state}` | |||||
| getUserOIDCSSOUrl().then((res) => { | |||||
| document.cookie = `user-oidc-state=${res.state}` | |||||
| router.push(res.url) | router.push(res.url) | ||||
| }).finally(() => { | }).finally(() => { | ||||
| setIsLoading(false) | setIsLoading(false) | ||||
| ) | ) | ||||
| } | } | ||||
| export default EnterpriseSSOForm | |||||
| export default UserSSOForm |
| WorkflowFinishedResponse, | WorkflowFinishedResponse, | ||||
| WorkflowStartedResponse, | WorkflowStartedResponse, | ||||
| } from '@/types/workflow' | } from '@/types/workflow' | ||||
| import { removeAccessToken } from '@/app/components/share/utils' | |||||
| const TIME_OUT = 100000 | const TIME_OUT = 100000 | ||||
| const ContentType = { | const ContentType = { | ||||
| }) | }) | ||||
| } | } | ||||
| function requiredWebSSOLogin() { | |||||
| globalThis.location.href = `/webapp-signin?redirect_url=${globalThis.location.pathname}` | |||||
| } | |||||
| export function format(text: string) { | export function format(text: string) { | ||||
| let res = text.trim() | let res = text.trim() | ||||
| if (res.startsWith('\n')) | if (res.startsWith('\n')) | ||||
| return bodyJson.then((data: ResponseError) => { | return bodyJson.then((data: ResponseError) => { | ||||
| if (!silent) | if (!silent) | ||||
| Toast.notify({ type: 'error', message: data.message }) | Toast.notify({ type: 'error', message: data.message }) | ||||
| if (data.code === 'web_sso_auth_required') | |||||
| requiredWebSSOLogin() | |||||
| if (data.code === 'unauthorized') { | |||||
| removeAccessToken() | |||||
| globalThis.location.reload() | |||||
| } | |||||
| return Promise.reject(data) | return Promise.reject(data) | ||||
| }) | }) | ||||
| } | } | ||||
| if (!/^(2|3)\d{2}$/.test(String(res.status))) { | if (!/^(2|3)\d{2}$/.test(String(res.status))) { | ||||
| res.json().then((data: any) => { | res.json().then((data: any) => { | ||||
| Toast.notify({ type: 'error', message: data.message || 'Server Error' }) | Toast.notify({ type: 'error', message: data.message || 'Server Error' }) | ||||
| if (isPublicAPI) { | |||||
| if (data.code === 'web_sso_auth_required') | |||||
| requiredWebSSOLogin() | |||||
| if (data.code === 'unauthorized') { | |||||
| removeAccessToken() | |||||
| globalThis.location.reload() | |||||
| } | |||||
| } | |||||
| }) | }) | ||||
| onError?.('Server Error') | onError?.('Server Error') | ||||
| return | return |
| ModelProvider, | ModelProvider, | ||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | } from '@/app/components/header/account-setting/model-provider-page/declarations' | ||||
| import type { RETRIEVE_METHOD } from '@/types/app' | import type { RETRIEVE_METHOD } from '@/types/app' | ||||
| import type { SystemFeatures } from '@/types/feature' | |||||
| export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => { | export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => { | ||||
| return post(url, { body }) as Promise<CommonResponse & { data: string }> | return post(url, { body }) as Promise<CommonResponse & { data: string }> | ||||
| export const fetchSupportRetrievalMethods: Fetcher<RetrievalMethodsRes, string> = (url) => { | export const fetchSupportRetrievalMethods: Fetcher<RetrievalMethodsRes, string> = (url) => { | ||||
| return get<RetrievalMethodsRes>(url) | return get<RetrievalMethodsRes>(url) | ||||
| } | } | ||||
| export const getSystemFeatures = () => { | |||||
| return get<SystemFeatures>('/system-features') | |||||
| } |
| import { get } from './base' | |||||
| import type { EnterpriseFeatures } from '@/types/enterprise' | |||||
| export const getEnterpriseFeatures = () => { | |||||
| return get<EnterpriseFeatures>('/enterprise-features') | |||||
| } | |||||
| export const getSAMLSSOUrl = () => { | |||||
| return get<{ url: string }>('/enterprise/sso/saml/login') | |||||
| } | |||||
| export const getOIDCSSOUrl = () => { | |||||
| return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') | |||||
| } |
| ConversationItem, | ConversationItem, | ||||
| } from '@/models/share' | } from '@/models/share' | ||||
| import type { ChatConfig } from '@/app/components/base/chat/types' | import type { ChatConfig } from '@/app/components/base/chat/types' | ||||
| import type { SystemFeatures } from '@/types/feature' | |||||
| function getAction(action: 'get' | 'post' | 'del' | 'patch', isInstalledApp: boolean) { | function getAction(action: 'get' | 'post' | 'del' | 'patch', isInstalledApp: boolean) { | ||||
| switch (action) { | switch (action) { | ||||
| return (getAction('get', isInstalledApp))(getUrl('parameters', isInstalledApp, installedAppId)) as Promise<ChatConfig> | return (getAction('get', isInstalledApp))(getUrl('parameters', isInstalledApp, installedAppId)) as Promise<ChatConfig> | ||||
| } | } | ||||
| export const fetchSystemFeatures = async () => { | |||||
| return (getAction('get', false))(getUrl('system-features', false, '')) as Promise<SystemFeatures> | |||||
| } | |||||
| export const fetchWebSAMLSSOUrl = async (appCode: string, redirectUrl: string) => { | |||||
| return (getAction('get', false))(getUrl('/enterprise/sso/saml/login', false, ''), { | |||||
| params: { | |||||
| app_code: appCode, | |||||
| redirect_url: redirectUrl, | |||||
| }, | |||||
| }) as Promise<{ url: string }> | |||||
| } | |||||
| export const fetchWebOIDCSSOUrl = async (appCode: string, redirectUrl: string) => { | |||||
| return (getAction('get', false))(getUrl('/enterprise/sso/oidc/login', false, ''), { | |||||
| params: { | |||||
| app_code: appCode, | |||||
| redirect_url: redirectUrl, | |||||
| }, | |||||
| }) as Promise<{ url: string }> | |||||
| } | |||||
| export const fetchAppMeta = async (isInstalledApp: boolean, installedAppId = '') => { | export const fetchAppMeta = async (isInstalledApp: boolean, installedAppId = '') => { | ||||
| return (getAction('get', isInstalledApp))(getUrl('meta', isInstalledApp, installedAppId)) as Promise<AppMeta> | return (getAction('get', isInstalledApp))(getUrl('meta', isInstalledApp, installedAppId)) as Promise<AppMeta> | ||||
| } | } |
| import { get } from './base' | |||||
| export const getUserSAMLSSOUrl = () => { | |||||
| return get<{ url: string }>('/enterprise/sso/saml/login') | |||||
| } | |||||
| export const getUserOIDCSSOUrl = () => { | |||||
| return get<{ url: string; state: string }>('/enterprise/sso/oidc/login') | |||||
| } |
| export type EnterpriseFeatures = { | |||||
| sso_enforced_for_signin: boolean | |||||
| sso_enforced_for_signin_protocol: string | |||||
| } | |||||
| export const defaultEnterpriseFeatures: EnterpriseFeatures = { | |||||
| sso_enforced_for_signin: false, | |||||
| sso_enforced_for_signin_protocol: '', | |||||
| } |
| export type SystemFeatures = { | |||||
| sso_enforced_for_signin: boolean | |||||
| sso_enforced_for_signin_protocol: string | |||||
| sso_enforced_for_web: boolean | |||||
| sso_enforced_for_web_protocol: string | |||||
| } | |||||
| export const defaultSystemFeatures: SystemFeatures = { | |||||
| sso_enforced_for_signin: false, | |||||
| sso_enforced_for_signin_protocol: '', | |||||
| sso_enforced_for_web: false, | |||||
| sso_enforced_for_web_protocol: '', | |||||
| } |