You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. #
  2. # Copyright 2019 The RAG Flow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import sys
  18. from importlib.util import module_from_spec, spec_from_file_location
  19. from pathlib import Path
  20. from flask import Blueprint, Flask, request
  21. from werkzeug.wrappers.request import Request
  22. from flask_cors import CORS
  23. from web_server.db import StatusEnum
  24. from web_server.db.services import UserService
  25. from web_server.utils import CustomJSONEncoder
  26. from flask_session import Session
  27. from flask_login import LoginManager
  28. from web_server.settings import RetCode, SECRET_KEY, stat_logger
  29. from web_server.hook import HookManager
  30. from web_server.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
  31. from web_server.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
  32. from web_server.utils.api_utils import get_json_result, server_error_response
  33. from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
  34. __all__ = ['app']
  35. logger = logging.getLogger('flask.app')
  36. for h in access_logger.handlers:
  37. logger.addHandler(h)
  38. Request.json = property(lambda self: self.get_json(force=True, silent=True))
  39. app = Flask(__name__)
  40. CORS(app, supports_credentials=True,max_age = 2592000)
  41. app.url_map.strict_slashes = False
  42. app.json_encoder = CustomJSONEncoder
  43. app.errorhandler(Exception)(server_error_response)
  44. ## convince for dev and debug
  45. #app.config["LOGIN_DISABLED"] = True
  46. app.config["SESSION_PERMANENT"] = False
  47. app.config["SESSION_TYPE"] = "filesystem"
  48. app.config['MAX_CONTENT_LENGTH'] = 64 * 1024 * 1024
  49. Session(app)
  50. login_manager = LoginManager()
  51. login_manager.init_app(app)
  52. def search_pages_path(pages_dir):
  53. return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]
  54. def register_page(page_path):
  55. page_name = page_path.stem.rstrip('_app')
  56. module_name = '.'.join(page_path.parts[page_path.parts.index('web_server'):-1] + (page_name, ))
  57. spec = spec_from_file_location(module_name, page_path)
  58. page = module_from_spec(spec)
  59. page.app = app
  60. page.manager = Blueprint(page_name, module_name)
  61. sys.modules[module_name] = page
  62. spec.loader.exec_module(page)
  63. page_name = getattr(page, 'page_name', page_name)
  64. url_prefix = f'/{API_VERSION}/{page_name}'
  65. app.register_blueprint(page.manager, url_prefix=url_prefix)
  66. return url_prefix
  67. pages_dir = [
  68. Path(__file__).parent,
  69. Path(__file__).parent.parent / 'web_server' / 'apps',
  70. ]
  71. client_urls_prefix = [
  72. register_page(path)
  73. for dir in pages_dir
  74. for path in search_pages_path(dir)
  75. ]
  76. def client_authentication_before_request():
  77. result = HookManager.client_authentication(ClientAuthenticationParameters(
  78. request.full_path, request.headers,
  79. request.form, request.data, request.json,
  80. ))
  81. if result.code != RetCode.SUCCESS:
  82. return get_json_result(result.code, result.message)
  83. def site_authentication_before_request():
  84. for url_prefix in client_urls_prefix:
  85. if request.path.startswith(url_prefix):
  86. return
  87. result = HookManager.site_authentication(AuthenticationParameters(
  88. request.headers.get('site_signature'),
  89. request.json,
  90. ))
  91. if result.code != RetCode.SUCCESS:
  92. return get_json_result(result.code, result.message)
  93. @app.before_request
  94. def authentication_before_request():
  95. if CLIENT_AUTHENTICATION:
  96. return client_authentication_before_request()
  97. if SITE_AUTHENTICATION:
  98. return site_authentication_before_request()
  99. @login_manager.request_loader
  100. def load_user(web_request):
  101. jwt = Serializer(secret_key=SECRET_KEY)
  102. authorization = web_request.headers.get("Authorization")
  103. if authorization:
  104. try:
  105. access_token = str(jwt.loads(authorization))
  106. user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value)
  107. if user:
  108. return user[0]
  109. else:
  110. return None
  111. except Exception as e:
  112. stat_logger.exception(e)
  113. return None
  114. else:
  115. return None