瀏覽代碼

Feat/api jwt (#1212)

tags/0.3.24
zxhlyh 2 年之前
父節點
當前提交
227f9fb77d
沒有連結到貢獻者的電子郵件帳戶。

+ 0
- 18
api/.env.example 查看文件

WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*


# Cookie configuration
COOKIE_HTTPONLY=true
COOKIE_SAMESITE=None
COOKIE_SECURE=true

# Session configuration
SESSION_PERMANENT=true
SESSION_USE_SIGNER=true

## support redis, sqlalchemy
SESSION_TYPE=redis

# session redis configuration
SESSION_REDIS_HOST=localhost
SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2

# Vector database configuration, support: weaviate, qdrant # Vector database configuration, support: weaviate, qdrant
VECTOR_STORE=weaviate VECTOR_STORE=weaviate



+ 21
- 72
api/app.py 查看文件

# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import os import os
from datetime import datetime, timedelta


from werkzeug.exceptions import Forbidden
from werkzeug.exceptions import Unauthorized


if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey from gevent import monkey
import json import json
import threading import threading


from flask import Flask, request, Response, session
import flask_login
from flask import Flask, request, Response
from flask_cors import CORS from flask_cors import CORS


from core.model_providers.providers import hosted from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_login import login_manager from extensions.ext_login import login_manager
from events import event_handlers from events import event_handlers
# DO NOT REMOVE ABOVE # DO NOT REMOVE ABOVE


import core
from config import Config, CloudEditionConfig from config import Config, CloudEditionConfig
from commands import register_commands from commands import register_commands
from models.account import TenantAccountJoin, AccountStatus
from models.model import Account, EndUser, App
from services.account_service import TenantService
from services.account_service import AccountService
from libs.passport import PassportService


import warnings import warnings
warnings.simplefilter("ignore", ResourceWarning) warnings.simplefilter("ignore", ResourceWarning)
ext_redis.init_app(app) ext_redis.init_app(app)
ext_storage.init_app(app) ext_storage.init_app(app)
ext_celery.init_app(app) ext_celery.init_app(app)
ext_session.init_app(app)
ext_login.init_app(app) ext_login.init_app(app)
ext_mail.init_app(app) ext_mail.init_app(app)
ext_sentry.init_app(app) ext_sentry.init_app(app)
ext_stripe.init_app(app) ext_stripe.init_app(app)




def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")

TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant

return tenant


# Flask-Login configuration # Flask-Login configuration
@login_manager.user_loader
def load_user(user_id):
"""Load user based on the user_id."""
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint == 'console': if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format # Check if the user_id contains a dot, indicating the old format
if '.' in user_id:
tenant_id, account_id = user_id.split('.')
else:
account_id = user_id

account = db.session.query(Account).filter(Account.id == account_id).first()

if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')

workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()

if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()

if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id

current_time = datetime.utcnow()

# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()

# Log in the user with the updated user_id
flask_login.login_user(account, remember=True)

return account
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')

return AccountService.load_user(user_id)
else: else:
return None return None



@login_manager.unauthorized_handler @login_manager.unauthorized_handler
def unauthorized_handler(): def unauthorized_handler():
"""Handle unauthorized requests.""" """Handle unauthorized requests."""
@app.after_request @app.after_request
def after_request(response): def after_request(response):
"""Add Version headers to the response.""" """Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION']) response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV']) response.headers.add('X-Env', app.config['DEPLOY_ENV'])
return response return response

+ 0
- 32
api/config.py 查看文件

dotenv.load_dotenv() dotenv.load_dotenv()


DEFAULTS = { DEFAULTS = {
'COOKIE_HTTPONLY': 'True',
'COOKIE_SECURE': 'True',
'COOKIE_SAMESITE': 'None',
'DB_USERNAME': 'postgres', 'DB_USERNAME': 'postgres',
'DB_PASSWORD': '', 'DB_PASSWORD': '',
'DB_HOST': 'localhost', 'DB_HOST': 'localhost',
'REDIS_PORT': '6379', 'REDIS_PORT': '6379',
'REDIS_DB': '0', 'REDIS_DB': '0',
'REDIS_USE_SSL': 'False', 'REDIS_USE_SSL': 'False',
'SESSION_REDIS_HOST': 'localhost',
'SESSION_REDIS_PORT': '6379',
'SESSION_REDIS_DB': '2',
'SESSION_REDIS_USE_SSL': 'False',
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize', 'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
'OAUTH_REDIRECT_INDEX_PATH': '/', 'OAUTH_REDIRECT_INDEX_PATH': '/',
'CONSOLE_WEB_URL': 'https://cloud.dify.ai', 'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
'STORAGE_TYPE': 'local', 'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage', 'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai', 'CHECK_UPDATE_URL': 'https://updates.dify.ai',
'SESSION_TYPE': 'sqlalchemy',
'SESSION_PERMANENT': 'True',
'SESSION_USE_SIGNER': 'True',
'DEPLOY_ENV': 'PRODUCTION', 'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30, 'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_POOL_RECYCLE': 3600, 'SQLALCHEMY_POOL_RECYCLE': 3600,
# Alternatively you can set it with `SECRET_KEY` environment variable. # Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY') self.SECRET_KEY = get_env('SECRET_KEY')


# cookie settings
self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)

# session settings, only support sqlalchemy, redis
self.SESSION_TYPE = get_env('SESSION_TYPE')
self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')

# redis settings # redis settings
self.REDIS_HOST = get_env('REDIS_HOST') self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT') self.REDIS_PORT = get_env('REDIS_PORT')
self.REDIS_DB = get_env('REDIS_DB') self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL') self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')


# session redis settings
self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')

# storage settings # storage settings
self.STORAGE_TYPE = get_env('STORAGE_TYPE') self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')

+ 2
- 3
api/controllers/console/auth/login.py 查看文件



import services import services
from controllers.console import api from controllers.console import api
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from libs.helper import email from libs.helper import email
from libs.password import valid_password from libs.password import valid_password
except Exception: except Exception:
pass pass


flask_login.login_user(account, remember=args['remember_me'])
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)


# todo: return the user info # todo: return the user info
token = AccountService.get_account_jwt_token(account)


return {'result': 'success'}
return {'result': 'success', 'data': token}




class LogoutApi(Resource): class LogoutApi(Resource):

+ 4
- 6
api/controllers/console/auth/oauth.py 查看文件

from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional


import flask_login
import requests import requests
from flask import request, redirect, current_app, session
from flask import request, redirect, current_app
from flask_restful import Resource from flask_restful import Resource


from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
account.initialized_at = datetime.utcnow() account.initialized_at = datetime.utcnow()
db.session.commit() db.session.commit()


# login user
session.clear()
flask_login.login_user(account, remember=True)
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)


return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
token = AccountService.get_account_jwt_token(account)

return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')




def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:

+ 0
- 4
api/controllers/console/setup.py 查看文件

# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from functools import wraps from functools import wraps


import flask_login
from flask import request, current_app from flask import request, current_app
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse


) )


setup() setup()

# Login
flask_login.login_user(account)
AccountService.update_last_login(account, request) AccountService.update_last_login(account, request)


return {'result': 'success'}, 201 return {'result': 'success'}, 201

+ 1
- 2
api/core/login/login.py 查看文件

import os import os
from functools import wraps from functools import wraps


import flask_login
from flask import current_app from flask import current_app
from flask import g from flask import g
from flask import has_request_context from flask import has_request_context
from flask import request
from flask import request, session
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_login.config import EXEMPT_METHODS from flask_login.config import EXEMPT_METHODS
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized

+ 0
- 174
api/extensions/ext_session.py 查看文件

import redis
from redis.connection import SSLConnection, Connection
from flask import request
from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface
from flask_session.sessions import total_seconds
from itsdangerous import want_bytes

from extensions.ext_database import db

sess = Session()


def init_app(app):
sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
app,
db,
app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)

session_type = app.config.get('SESSION_TYPE')
if session_type == 'sqlalchemy':
app.session_interface = sqlalchemy_session_interface
elif session_type == 'redis':
connection_class = Connection
if app.config.get('SESSION_REDIS_USE_SSL', False):
connection_class = SSLConnection

sess_redis_client = redis.Redis()
sess_redis_client.connection_pool = redis.ConnectionPool(**{
'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
'port': app.config.get('SESSION_REDIS_PORT', 6379),
'username': app.config.get('SESSION_REDIS_USERNAME', None),
'password': app.config.get('SESSION_REDIS_PASSWORD', None),
'db': app.config.get('SESSION_REDIS_DB', 2),
'encoding': 'utf-8',
'encoding_errors': 'strict',
'decode_responses': False
}, connection_class=connection_class)

app.extensions['session_redis'] = sess_redis_client

app.session_interface = CustomRedisSessionInterface(
sess_redis_client,
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)


class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):

def __init__(
self,
app,
db,
table,
key_prefix,
use_signer=False,
permanent=True,
sequence=None,
autodelete=False,
):
if db is None:
from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy(app)
self.db = db
self.key_prefix = key_prefix
self.use_signer = use_signer
self.permanent = permanent
self.autodelete = autodelete
self.sequence = sequence
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")

class Session(self.db.Model):
__tablename__ = table

if sequence:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, self.db.Sequence(sequence), primary_key=True
)
else:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, primary_key=True
)

session_id = self.db.Column(self.db.String(255), unique=True)
data = self.db.Column(self.db.LargeBinary)
expiry = self.db.Column(self.db.DateTime)

def __init__(self, session_id, data, expiry):
self.session_id = session_id
self.data = data
self.expiry = expiry

def __repr__(self):
return f"<Session data {self.data}>"

self.sql_session_model = Session

def save_session(self, *args, **kwargs):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return
return super().save_session(*args, **kwargs)


class CustomRedisSessionInterface(RedisSessionInterface):

def save_session(self, app, session, response):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return

if not self.should_set_cookie(app, session):
return
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
if not session:
if session.modified:
self.redis.delete(self.key_prefix + session.sid)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return

# Modification case. There are upsides and downsides to
# emitting a set-cookie header each request. The behavior
# is controlled by the :meth:`should_set_cookie` method
# which performs a quick check to figure out if the cookie
# should be set or not. This is controlled by the
# SESSION_REFRESH_EACH_REQUEST config flag as well as
# the permanent flag on the session itself.
# if not self.should_set_cookie(app, session):
# return
conditional_cookie_kwargs = {}
httponly = self.get_cookie_httponly(app)
secure = self.get_cookie_secure(app)
if self.has_same_site_capability:
conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
expires = self.get_expiration_time(app, session)

if session.permanent:
value = self.serializer.dumps(dict(session))
if value is not None:
self.redis.setex(
name=self.key_prefix + session.sid,
value=value,
time=total_seconds(app.permanent_session_lifetime),
)

if self.use_signer:
session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
else:
session_id = session.sid
response.set_cookie(
app.config["SESSION_COOKIE_NAME"],
session_id,
expires=expires,
httponly=httponly,
domain=domain,
path=path,
secure=secure,
**conditional_cookie_kwargs,
)

+ 71
- 4
api/services/account_service.py 查看文件

import logging import logging
import secrets import secrets
import uuid import uuid
from datetime import datetime
from datetime import datetime, timedelta
from hashlib import sha256 from hashlib import sha256
from typing import Optional from typing import Optional


from flask import session
from werkzeug.exceptions import Forbidden, Unauthorized
from flask import session, current_app
from sqlalchemy import func from sqlalchemy import func


from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from libs.helper import get_remote_ip from libs.helper import get_remote_ip
from libs.password import compare_password, hash_password from libs.password import compare_password, hash_password
from libs.rsa import generate_key_pair from libs.rsa import generate_key_pair
from libs.passport import PassportService
from models.account import * from models.account import *
from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task


def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")

TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant

return tenant



class AccountService: class AccountService:


@staticmethod @staticmethod
def load_user(account_id: int) -> Account:
def load_user(user_id: str) -> Account:
# todo: used by flask_login # todo: used by flask_login
pass
if '.' in user_id:
tenant_id, account_id = user_id.split('.')
else:
account_id = user_id

account = db.session.query(Account).filter(Account.id == account_id).first()

if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')

workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()

if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()

if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id

current_time = datetime.utcnow()

# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()

return account
@staticmethod
def get_account_jwt_token(account):
payload = {
"user_id": account.id,
"exp": datetime.utcnow() + timedelta(days=30),
"iss": current_app.config['EDITION'],
"sub": 'Console API Passport',
}

token = PassportService().issue(payload)
return token


@staticmethod @staticmethod
def authenticate(email: str, password: str) -> Account: def authenticate(email: str, password: str) -> Account:

+ 0
- 13
docker/docker-compose.yaml 查看文件

REDIS_USE_SSL: 'false' REDIS_USE_SSL: 'false'
# use redis db 0 for redis cache # use redis db 0 for redis cache
REDIS_DB: 0 REDIS_DB: 0
# The configurations of session, Supported values are `sqlalchemy`. `redis`
SESSION_TYPE: redis
SESSION_REDIS_HOST: redis
SESSION_REDIS_PORT: 6379
SESSION_REDIS_USERNAME: ''
SESSION_REDIS_PASSWORD: difyai123456
SESSION_REDIS_USE_SSL: 'false'
# use redis db 2 for session store
SESSION_REDIS_DB: 2
# The configurations of celery broker. # The configurations of celery broker.
# Use redis as the broker, and redis db 1 for celery broker. # Use redis as the broker, and redis db 1 for celery broker.
CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1 CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1
# If you want to enable cross-origin support, # If you want to enable cross-origin support,
# you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`. # you must use the HTTPS protocol and set the configuration to `SameSite=None, Secure=true, HttpOnly=true`.
# #
# For **production** purposes, please set `SameSite=Lax, Secure=true, HttpOnly=true`.
COOKIE_HTTPONLY: 'true'
COOKIE_SAMESITE: 'Lax'
COOKIE_SECURE: 'false'
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
STORAGE_TYPE: local STORAGE_TYPE: local
# The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`. # The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`.

+ 28
- 7
web/app/components/swr-initor.tsx 查看文件

'use client' 'use client'


import { SWRConfig } from 'swr' import { SWRConfig } from 'swr'
import { useEffect, useState } from 'react'
import type { ReactNode } from 'react' import type { ReactNode } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'


type SwrInitorProps = { type SwrInitorProps = {
children: ReactNode children: ReactNode
const SwrInitor = ({ const SwrInitor = ({
children, children,
}: SwrInitorProps) => { }: SwrInitorProps) => {
return (
<SWRConfig value={{
shouldRetryOnError: false,
}}>
{children}
</SWRConfig>
)
const router = useRouter()
const searchParams = useSearchParams()
const consoleToken = searchParams.get('console_token')
const consoleTokenFromLocalStorage = localStorage?.getItem('console_token')
const [init, setInit] = useState(false)

useEffect(() => {
if (!(consoleToken || consoleTokenFromLocalStorage))
router.replace('/signin')

if (consoleToken) {
localStorage?.setItem('console_token', consoleToken!)
router.replace('/apps', { forceOptimisticNavigation: false })
}
setInit(true)
}, [])

return init
? (
<SWRConfig value={{
shouldRetryOnError: false,
}}>
{children}
</SWRConfig>
)
: null
} }


export default SwrInitor export default SwrInitor

+ 4
- 0
web/app/signin/_header.tsx 查看文件



const Header = () => { const Header = () => {
const { locale, setLocaleOnClient } = useContext(I18n) const { locale, setLocaleOnClient } = useContext(I18n)

if (localStorage?.getItem('console_token'))
localStorage.removeItem('console_token')

return <div className='flex items-center justify-between p-6 w-full'> return <div className='flex items-center justify-between p-6 w-full'>
<div className={style.logo}></div> <div className={style.logo}></div>
<Select <Select

+ 3
- 2
web/app/signin/normalForm.tsx 查看文件

} }
try { try {
setIsLoading(true) setIsLoading(true)
await login({
const res = await login({
url: '/login', url: '/login',
body: { body: {
email, email,
remember_me: true, remember_me: true,
}, },
}) })
router.push('/apps')
localStorage.setItem('console_token', res.data)
router.replace('/apps')
} }
finally { finally {
setIsLoading(false) setIsLoading(false)

+ 7
- 1
web/service/base.ts 查看文件

} }
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`) options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
} }
else {
const accessToken = localStorage.getItem('console_token') || ''
options.headers.set('Authorization', `Bearer ${accessToken}`)
}


if (deleteContentType) { if (deleteContentType) {
options.headers.delete('Content-Type') options.headers.delete('Content-Type')
const defaultOptions = { const defaultOptions = {
method: 'POST', method: 'POST',
url: `${API_PREFIX}/files/upload`, url: `${API_PREFIX}/files/upload`,
headers: {},
headers: {
Authorization: `Bearer ${localStorage.getItem('console_token') || ''}`,
},
data: {}, data: {},
} }
options = { options = {

+ 2
- 2
web/service/common.ts 查看文件

} from '@/models/app' } from '@/models/app'
import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations' import type { BackendModel, ProviderMap } from '@/app/components/header/account-setting/model-page/declarations'


export const login: Fetcher<CommonResponse, { url: string; body: Record<string, any> }> = ({ url, body }) => {
return post<CommonResponse>(url, { body })
export const login: Fetcher<CommonResponse & { data: string }, { url: string; body: Record<string, any> }> = ({ url, body }) => {
return post(url, { body }) as Promise<CommonResponse & { data: string }>
} }


export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => { export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => {

Loading…
取消
儲存