| @@ -10,44 +10,19 @@ if os.environ.get("DEBUG", "false").lower() != "true": | |||
| grpc.experimental.gevent.init_gevent() | |||
| import json | |||
| import logging | |||
| import sys | |||
| import threading | |||
| import time | |||
| import warnings | |||
| from logging.handlers import RotatingFileHandler | |||
| from flask import Flask, Response, request | |||
| from flask_cors import CORS | |||
| from werkzeug.exceptions import Unauthorized | |||
| from flask import Response | |||
| import contexts | |||
| from commands import register_commands | |||
| from configs import dify_config | |||
| from app_factory import create_app | |||
| # DO NOT REMOVE BELOW | |||
| from events import event_handlers # noqa: F401 | |||
| from extensions import ( | |||
| ext_celery, | |||
| ext_code_based_extension, | |||
| ext_compress, | |||
| ext_database, | |||
| ext_hosting_provider, | |||
| ext_login, | |||
| ext_mail, | |||
| ext_migrate, | |||
| ext_proxy_fix, | |||
| ext_redis, | |||
| ext_sentry, | |||
| ext_storage, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from extensions.ext_login import login_manager | |||
| from libs.passport import PassportService | |||
| # TODO: Find a way to avoid importing models here | |||
| from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 | |||
| from services.account_service import AccountService | |||
| # DO NOT REMOVE ABOVE | |||
| @@ -60,188 +35,12 @@ if hasattr(time, "tzset"): | |||
| time.tzset() | |||
| class DifyApp(Flask): | |||
| pass | |||
| # ------------- | |||
| # Configuration | |||
| # ------------- | |||
| config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first | |||
| # ---------------------------- | |||
| # Application Factory Function | |||
| # ---------------------------- | |||
| def create_flask_app_with_configs() -> Flask: | |||
| """ | |||
| create a raw flask app | |||
| with configs loaded from .env file | |||
| """ | |||
| dify_app = DifyApp(__name__) | |||
| dify_app.config.from_mapping(dify_config.model_dump()) | |||
| # populate configs into system environment variables | |||
| for key, value in dify_app.config.items(): | |||
| if isinstance(value, str): | |||
| os.environ[key] = value | |||
| elif isinstance(value, int | float | bool): | |||
| os.environ[key] = str(value) | |||
| elif value is None: | |||
| os.environ[key] = "" | |||
| return dify_app | |||
| def create_app() -> Flask: | |||
| app = create_flask_app_with_configs() | |||
| app.secret_key = app.config["SECRET_KEY"] | |||
| log_handlers = None | |||
| log_file = app.config.get("LOG_FILE") | |||
| if log_file: | |||
| log_dir = os.path.dirname(log_file) | |||
| os.makedirs(log_dir, exist_ok=True) | |||
| log_handlers = [ | |||
| RotatingFileHandler( | |||
| filename=log_file, | |||
| maxBytes=1024 * 1024 * 1024, | |||
| backupCount=5, | |||
| ), | |||
| logging.StreamHandler(sys.stdout), | |||
| ] | |||
| logging.basicConfig( | |||
| level=app.config.get("LOG_LEVEL"), | |||
| format=app.config.get("LOG_FORMAT"), | |||
| datefmt=app.config.get("LOG_DATEFORMAT"), | |||
| handlers=log_handlers, | |||
| force=True, | |||
| ) | |||
| log_tz = app.config.get("LOG_TZ") | |||
| if log_tz: | |||
| from datetime import datetime | |||
| import pytz | |||
| timezone = pytz.timezone(log_tz) | |||
| def time_converter(seconds): | |||
| return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() | |||
| for handler in logging.root.handlers: | |||
| handler.formatter.converter = time_converter | |||
| initialize_extensions(app) | |||
| register_blueprints(app) | |||
| register_commands(app) | |||
| return app | |||
| def initialize_extensions(app): | |||
| # Since the application instance is now created, pass it to each Flask | |||
| # extension instance to bind it to the Flask application instance (app) | |||
| ext_compress.init_app(app) | |||
| ext_code_based_extension.init() | |||
| ext_database.init_app(app) | |||
| ext_migrate.init(app, db) | |||
| ext_redis.init_app(app) | |||
| ext_storage.init_app(app) | |||
| ext_celery.init_app(app) | |||
| ext_login.init_app(app) | |||
| ext_mail.init_app(app) | |||
| ext_hosting_provider.init_app(app) | |||
| ext_sentry.init_app(app) | |||
| ext_proxy_fix.init_app(app) | |||
| # Flask-Login configuration | |||
| @login_manager.request_loader | |||
| def load_user_from_request(request_from_flask_login): | |||
| """Load user based on the request.""" | |||
| if request.blueprint not in {"console", "inner_api"}: | |||
| return None | |||
| # Check if the user_id contains a dot, indicating the old format | |||
| auth_header = request.headers.get("Authorization", "") | |||
| if not auth_header: | |||
| auth_token = request.args.get("_token") | |||
| if not auth_token: | |||
| raise Unauthorized("Invalid Authorization token.") | |||
| else: | |||
| if " " not in auth_header: | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != "bearer": | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| decoded = PassportService().verify(auth_token) | |||
| user_id = decoded.get("user_id") | |||
| logged_in_account = AccountService.load_logged_in_account(account_id=user_id) | |||
| if logged_in_account: | |||
| contexts.tenant_id.set(logged_in_account.current_tenant_id) | |||
| return logged_in_account | |||
| @login_manager.unauthorized_handler | |||
| def unauthorized_handler(): | |||
| """Handle unauthorized requests.""" | |||
| return Response( | |||
| json.dumps({"code": "unauthorized", "message": "Unauthorized."}), | |||
| status=401, | |||
| content_type="application/json", | |||
| ) | |||
| # register blueprint routers | |||
| def register_blueprints(app): | |||
| from controllers.console import bp as console_app_bp | |||
| from controllers.files import bp as files_bp | |||
| from controllers.inner_api import bp as inner_api_bp | |||
| from controllers.service_api import bp as service_api_bp | |||
| from controllers.web import bp as web_bp | |||
| CORS( | |||
| service_api_bp, | |||
| allow_headers=["Content-Type", "Authorization", "X-App-Code"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| ) | |||
| app.register_blueprint(service_api_bp) | |||
| CORS( | |||
| web_bp, | |||
| resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, | |||
| supports_credentials=True, | |||
| allow_headers=["Content-Type", "Authorization", "X-App-Code"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| expose_headers=["X-Version", "X-Env"], | |||
| ) | |||
| app.register_blueprint(web_bp) | |||
| CORS( | |||
| console_app_bp, | |||
| resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, | |||
| supports_credentials=True, | |||
| allow_headers=["Content-Type", "Authorization"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| expose_headers=["X-Version", "X-Env"], | |||
| ) | |||
| app.register_blueprint(console_app_bp) | |||
| CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) | |||
| app.register_blueprint(files_bp) | |||
| app.register_blueprint(inner_api_bp) | |||
| # create app | |||
| app = create_app() | |||
| celery = app.extensions["celery"] | |||
| @@ -0,0 +1,213 @@ | |||
| import os | |||
| if os.environ.get("DEBUG", "false").lower() != "true": | |||
| from gevent import monkey | |||
| monkey.patch_all() | |||
| import grpc.experimental.gevent | |||
| grpc.experimental.gevent.init_gevent() | |||
| import json | |||
| import logging | |||
| import sys | |||
| from logging.handlers import RotatingFileHandler | |||
| from flask import Flask, Response, request | |||
| from flask_cors import CORS | |||
| from werkzeug.exceptions import Unauthorized | |||
| import contexts | |||
| from commands import register_commands | |||
| from configs import dify_config | |||
| from extensions import ( | |||
| ext_celery, | |||
| ext_code_based_extension, | |||
| ext_compress, | |||
| ext_database, | |||
| ext_hosting_provider, | |||
| ext_login, | |||
| ext_mail, | |||
| ext_migrate, | |||
| ext_proxy_fix, | |||
| ext_redis, | |||
| ext_sentry, | |||
| ext_storage, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from extensions.ext_login import login_manager | |||
| from libs.passport import PassportService | |||
| from services.account_service import AccountService | |||
| class DifyApp(Flask): | |||
| pass | |||
| # ---------------------------- | |||
| # Application Factory Function | |||
| # ---------------------------- | |||
| def create_flask_app_with_configs() -> Flask: | |||
| """ | |||
| create a raw flask app | |||
| with configs loaded from .env file | |||
| """ | |||
| dify_app = DifyApp(__name__) | |||
| dify_app.config.from_mapping(dify_config.model_dump()) | |||
| # populate configs into system environment variables | |||
| for key, value in dify_app.config.items(): | |||
| if isinstance(value, str): | |||
| os.environ[key] = value | |||
| elif isinstance(value, int | float | bool): | |||
| os.environ[key] = str(value) | |||
| elif value is None: | |||
| os.environ[key] = "" | |||
| return dify_app | |||
| def create_app() -> Flask: | |||
| app = create_flask_app_with_configs() | |||
| app.secret_key = app.config["SECRET_KEY"] | |||
| log_handlers = None | |||
| log_file = app.config.get("LOG_FILE") | |||
| if log_file: | |||
| log_dir = os.path.dirname(log_file) | |||
| os.makedirs(log_dir, exist_ok=True) | |||
| log_handlers = [ | |||
| RotatingFileHandler( | |||
| filename=log_file, | |||
| maxBytes=1024 * 1024 * 1024, | |||
| backupCount=5, | |||
| ), | |||
| logging.StreamHandler(sys.stdout), | |||
| ] | |||
| logging.basicConfig( | |||
| level=app.config.get("LOG_LEVEL"), | |||
| format=app.config.get("LOG_FORMAT"), | |||
| datefmt=app.config.get("LOG_DATEFORMAT"), | |||
| handlers=log_handlers, | |||
| force=True, | |||
| ) | |||
| log_tz = app.config.get("LOG_TZ") | |||
| if log_tz: | |||
| from datetime import datetime | |||
| import pytz | |||
| timezone = pytz.timezone(log_tz) | |||
| def time_converter(seconds): | |||
| return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() | |||
| for handler in logging.root.handlers: | |||
| handler.formatter.converter = time_converter | |||
| initialize_extensions(app) | |||
| register_blueprints(app) | |||
| register_commands(app) | |||
| return app | |||
| def initialize_extensions(app): | |||
| # Since the application instance is now created, pass it to each Flask | |||
| # extension instance to bind it to the Flask application instance (app) | |||
| ext_compress.init_app(app) | |||
| ext_code_based_extension.init() | |||
| ext_database.init_app(app) | |||
| ext_migrate.init(app, db) | |||
| ext_redis.init_app(app) | |||
| ext_storage.init_app(app) | |||
| ext_celery.init_app(app) | |||
| ext_login.init_app(app) | |||
| ext_mail.init_app(app) | |||
| ext_hosting_provider.init_app(app) | |||
| ext_sentry.init_app(app) | |||
| ext_proxy_fix.init_app(app) | |||
| # Flask-Login configuration | |||
| @login_manager.request_loader | |||
| def load_user_from_request(request_from_flask_login): | |||
| """Load user based on the request.""" | |||
| if request.blueprint not in {"console", "inner_api"}: | |||
| return None | |||
| # Check if the user_id contains a dot, indicating the old format | |||
| auth_header = request.headers.get("Authorization", "") | |||
| if not auth_header: | |||
| auth_token = request.args.get("_token") | |||
| if not auth_token: | |||
| raise Unauthorized("Invalid Authorization token.") | |||
| else: | |||
| if " " not in auth_header: | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != "bearer": | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| decoded = PassportService().verify(auth_token) | |||
| user_id = decoded.get("user_id") | |||
| logged_in_account = AccountService.load_logged_in_account(account_id=user_id) | |||
| if logged_in_account: | |||
| contexts.tenant_id.set(logged_in_account.current_tenant_id) | |||
| return logged_in_account | |||
| @login_manager.unauthorized_handler | |||
| def unauthorized_handler(): | |||
| """Handle unauthorized requests.""" | |||
| return Response( | |||
| json.dumps({"code": "unauthorized", "message": "Unauthorized."}), | |||
| status=401, | |||
| content_type="application/json", | |||
| ) | |||
| # register blueprint routers | |||
| def register_blueprints(app): | |||
| from controllers.console import bp as console_app_bp | |||
| from controllers.files import bp as files_bp | |||
| from controllers.inner_api import bp as inner_api_bp | |||
| from controllers.service_api import bp as service_api_bp | |||
| from controllers.web import bp as web_bp | |||
| CORS( | |||
| service_api_bp, | |||
| allow_headers=["Content-Type", "Authorization", "X-App-Code"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| ) | |||
| app.register_blueprint(service_api_bp) | |||
| CORS( | |||
| web_bp, | |||
| resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, | |||
| supports_credentials=True, | |||
| allow_headers=["Content-Type", "Authorization", "X-App-Code"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| expose_headers=["X-Version", "X-Env"], | |||
| ) | |||
| app.register_blueprint(web_bp) | |||
| CORS( | |||
| console_app_bp, | |||
| resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, | |||
| supports_credentials=True, | |||
| allow_headers=["Content-Type", "Authorization"], | |||
| methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], | |||
| expose_headers=["X-Version", "X-Env"], | |||
| ) | |||
| app.register_blueprint(console_app_bp) | |||
| CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) | |||
| app.register_blueprint(files_bp) | |||
| app.register_blueprint(inner_api_bp) | |||
| @@ -0,0 +1,24 @@ | |||
| import pytest | |||
| from app_factory import create_app | |||
| mock_user = type( | |||
| "MockUser", | |||
| (object,), | |||
| { | |||
| "is_authenticated": True, | |||
| "id": "123", | |||
| "is_editor": True, | |||
| "is_dataset_editor": True, | |||
| "status": "active", | |||
| "get_id": "123", | |||
| "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b", | |||
| }, | |||
| ) | |||
| @pytest.fixture | |||
| def app(): | |||
| app = create_app() | |||
| app.config["LOGIN_DISABLED"] = True | |||
| return app | |||
| @@ -0,0 +1,10 @@ | |||
| from unittest.mock import patch | |||
| from app_fixture import app, mock_user | |||
| def test_post_requires_login(app): | |||
| with app.test_client() as client: | |||
| with patch("flask_login.utils._get_user", mock_user): | |||
| response = client.get("/console/api/data-source/integrates") | |||
| assert response.status_code == 200 | |||