| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- #
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import os
- import sys
- import logging
- from importlib.util import module_from_spec, spec_from_file_location
- from pathlib import Path
- from flask import Blueprint, Flask
- from werkzeug.wrappers.request import Request
- from flask_cors import CORS
- from flasgger import Swagger
-
- from api.db import StatusEnum
- from api.db.db_models import close_connection
- from api.db.services import UserService
- from api.utils import CustomJSONEncoder, commands
-
- from flask_session import Session
- from flask_login import LoginManager
- from api.settings import SECRET_KEY
- from api.settings import API_VERSION
- from api.utils.api_utils import server_error_response
- from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
-
# Only the Flask application object is exported by ``from ... import *``.
__all__ = ["app"]

# Replace werkzeug's ``Request.json`` with a lenient property: the body is
# parsed as JSON regardless of the declared content type (force=True) and a
# body that fails to parse yields None instead of raising (silent=True).
Request.json = property(
    lambda request: request.get_json(force=True, silent=True)
)

# The application object that every page module attaches its blueprint to.
app = Flask(__name__)
-
# Flasgger configuration: a single spec document that covers every route and
# every model, served at /apispec.json with the interactive UI at /apidocs/.
swagger_config = {
    "headers": [],
    "specs": [
        {
            "endpoint": "apispec",
            "route": "/apispec.json",
            # Include all endpoints.
            "rule_filter": lambda rule: True,
            # Include all models.
            "model_filter": lambda tag: True,
        },
    ],
    "static_url_path": "/flasgger_static",
    "swagger_ui": True,
    "specs_route": "/apidocs/",
}

# Swagger 2.0 template; API keys are passed in the ``Authorization`` header.
swagger = Swagger(
    app,
    config=swagger_config,
    template={
        "swagger": "2.0",
        "info": {
            "title": "RAGFlow API",
            "description": "",
            "version": "1.0.0",
        },
        "securityDefinitions": {
            "ApiKeyAuth": {
                "type": "apiKey",
                "name": "Authorization",
                "in": "header",
            },
        },
    },
)
-
# Allow credentialed cross-origin requests; browsers may cache the preflight
# response for 30 days (2592000 seconds).
CORS(app, supports_credentials=True, max_age=30 * 24 * 60 * 60)

# "/path" and "/path/" resolve to the same view.
app.url_map.strict_slashes = False
# NOTE(review): Flask 2.3 removed ``app.json_encoder``; confirm the pinned
# Flask version still honours this assignment.
app.json_encoder = CustomJSONEncoder
# Any exception that is not handled elsewhere is turned into a response by
# server_error_response.
app.register_error_handler(Exception, server_error_response)


# Convenience for development and debugging: uncomment to bypass login checks.
# app.config["LOGIN_DISABLED"] = True
app.config.update(
    SESSION_PERMANENT=False,
    SESSION_TYPE="filesystem",
    # Upload size limit in bytes; defaults to 128 MiB unless the
    # MAX_CONTENT_LENGTH environment variable overrides it.
    MAX_CONTENT_LENGTH=int(
        os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)
    ),
)

# Extensions are initialised only after the configuration above is in place.
Session(app)
login_manager = LoginManager()
login_manager.init_app(app)

commands.register_commands(app)
-
-
def search_pages_path(pages_dir):
    """Collect the page module files found under *pages_dir*.

    Two glob patterns are evaluated relative to *pages_dir*:

    * ``*_app.py``  -- page modules directly inside the directory;
    * ``*sdk/*.py`` -- every Python file inside a sub-directory whose name
      ends with ``sdk``.

    Files whose name starts with a dot are skipped.

    Args:
        pages_dir: a ``pathlib.Path`` (or any object offering ``glob``).

    Returns:
        A new list of paths: the ``*_app.py`` matches first, followed by
        the SDK matches, each group in the order ``glob`` produced them.
    """

    def _visible_matches(pattern):
        # Hidden files (leading dot) are never treated as pages.
        return [
            candidate
            for candidate in pages_dir.glob(pattern)
            if not candidate.name.startswith(".")
        ]

    return _visible_matches("*_app.py") + _visible_matches("*sdk/*.py")
-
-
def register_page(page_path):
    """Import one page module and mount its blueprint on the Flask app.

    The module is loaded from *page_path* under a dotted name built from the
    path components starting at the first one named ``api``. Before the
    module body runs, two attributes are injected into it: ``app`` (the Flask
    application) and ``manager`` (a fresh ``Blueprint``), so the module's own
    code can refer to them as globals.

    The URL prefix is ``/api/<API_VERSION>`` for modules that live in an
    ``sdk`` directory and ``/<API_VERSION>/<page_name>`` otherwise. A module
    may override the name used in the prefix by defining ``page_name``.

    Args:
        page_path: ``pathlib.Path`` of the page module file.

    Returns:
        The URL prefix under which the blueprint was registered.

    Raises:
        ValueError: if no component of *page_path* is named ``api``.
    """
    # Remove only a literal "_app" suffix. ``str.rstrip("_app")`` treats its
    # argument as a *set of characters* and would keep stripping any trailing
    # "_", "a" or "p" (e.g. "map_app" -> "m", "data" -> "dat").
    suffix = "_app"
    stem = page_path.stem
    page_name = stem[: -len(suffix)] if stem.endswith(suffix) else stem

    parts = page_path.parts
    module_name = ".".join(parts[parts.index("api") : -1] + (page_name,))

    spec = spec_from_file_location(module_name, page_path)
    page = module_from_spec(spec)
    # These must be set before exec_module so the module body can use them.
    page.app = app
    page.manager = Blueprint(page_name, module_name)
    sys.modules[module_name] = page
    spec.loader.exec_module(page)

    # The module may declare its own name for the URL prefix.
    page_name = getattr(page, "page_name", page_name)

    # Compare path components rather than searching the string for "/sdk/",
    # so the check also works with Windows path separators.
    in_sdk_dir = "sdk" in parts[:-1]
    url_prefix = (
        f"/api/{API_VERSION}" if in_sdk_dir else f"/{API_VERSION}/{page_name}"
    )

    app.register_blueprint(page.manager, url_prefix=url_prefix)
    return url_prefix
-
-
# Directories that are scanned for page modules, in scan order.
pages_dir = [
    Path(__file__).parent,
    Path(__file__).parent.parent.joinpath("api", "apps"),
    Path(__file__).parent.parent.joinpath("api", "apps", "sdk"),
]

# Register every discovered page at import time and keep the URL prefix each
# one was mounted under, in registration order.
client_urls_prefix = [
    register_page(page_file)
    for pages_root in pages_dir
    for page_file in search_pages_path(pages_root)
]
-
-
@login_manager.request_loader
def load_user(web_request):
    """Resolve the user for a request from its ``Authorization`` header.

    The header value is expected to be a token signed with ``SECRET_KEY``;
    its decoded payload is the user's access token, which is looked up among
    users whose status is ``StatusEnum.VALID``.

    Args:
        web_request: the incoming request; only ``headers`` is read.

    Returns:
        The first matching user, or ``None`` when the header is missing, the
        token is blank, no user matches, or any error occurs while decoding
        the token or querying for the user.
    """
    authorization = web_request.headers.get("Authorization")
    if not authorization:
        return None

    serializer = Serializer(secret_key=SECRET_KEY)
    try:
        access_token = str(serializer.loads(authorization))
        # A blank token must never reach the lookup: it could otherwise
        # match an account whose stored access_token is blank.
        # NOTE(review): presumably possible for logged-out users -- confirm
        # against how UserService stores access_token.
        if not access_token.strip():
            logging.warning("load_user rejected an empty access token")
            return None

        users = UserService.query(
            access_token=access_token, status=StatusEnum.VALID.value
        )
        return users[0] if users else None
    except Exception:
        # Authentication is best-effort at this boundary: log the failure
        # and treat the request as anonymous.
        logging.exception("load_user got exception")
        return None
-
-
@app.teardown_request
def _db_close(exc):
    """Close the database connection when a request is torn down.

    Args:
        exc: the argument Flask passes to teardown callbacks; it is not used.
    """
    close_connection()
|