| @@ -42,6 +42,15 @@ REDIS_PORT=6379 | |||
| REDIS_USERNAME= | |||
| REDIS_PASSWORD=difyai123456 | |||
| REDIS_USE_SSL=false | |||
| # SSL configuration for Redis (when REDIS_USE_SSL=true) | |||
| REDIS_SSL_CERT_REQS=CERT_NONE | |||
| # Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED | |||
| REDIS_SSL_CA_CERTS= | |||
| # Path to CA certificate file for SSL verification | |||
| REDIS_SSL_CERTFILE= | |||
| # Path to client certificate file for SSL authentication | |||
| REDIS_SSL_KEYFILE= | |||
| # Path to client private key file for SSL authentication | |||
| REDIS_DB=0 | |||
| # redis Sentinel configuration. | |||
| @@ -39,6 +39,26 @@ class RedisConfig(BaseSettings): | |||
| default=False, | |||
| ) | |||
| REDIS_SSL_CERT_REQS: str = Field( | |||
| description="SSL certificate requirements (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)", | |||
| default="CERT_NONE", | |||
| ) | |||
| REDIS_SSL_CA_CERTS: Optional[str] = Field( | |||
| description="Path to the CA certificate file for SSL verification", | |||
| default=None, | |||
| ) | |||
| REDIS_SSL_CERTFILE: Optional[str] = Field( | |||
| description="Path to the client certificate file for SSL authentication", | |||
| default=None, | |||
| ) | |||
| REDIS_SSL_KEYFILE: Optional[str] = Field( | |||
| description="Path to the client private key file for SSL authentication", | |||
| default=None, | |||
| ) | |||
| REDIS_USE_SENTINEL: Optional[bool] = Field( | |||
| description="Enable Redis Sentinel mode for high availability", | |||
| default=False, | |||
| @@ -1,4 +1,6 @@ | |||
| import ssl | |||
| from datetime import timedelta | |||
| from typing import Any, Optional | |||
| import pytz | |||
| from celery import Celery, Task # type: ignore | |||
| @@ -8,6 +10,40 @@ from configs import dify_config | |||
| from dify_app import DifyApp | |||
| def _get_celery_ssl_options() -> Optional[dict[str, Any]]: | |||
| """Get SSL configuration for Celery broker/backend connections.""" | |||
| # Use REDIS_USE_SSL for consistency with the main Redis client | |||
| # Only apply SSL if we're using Redis as broker/backend | |||
| if not dify_config.REDIS_USE_SSL: | |||
| return None | |||
| # Check if Celery is actually using Redis | |||
| broker_is_redis = dify_config.CELERY_BROKER_URL and ( | |||
| dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://") | |||
| ) | |||
| if not broker_is_redis: | |||
| return None | |||
| # Map certificate requirement strings to SSL constants | |||
| cert_reqs_map = { | |||
| "CERT_NONE": ssl.CERT_NONE, | |||
| "CERT_OPTIONAL": ssl.CERT_OPTIONAL, | |||
| "CERT_REQUIRED": ssl.CERT_REQUIRED, | |||
| } | |||
| ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) | |||
| ssl_options = { | |||
| "ssl_cert_reqs": ssl_cert_reqs, | |||
| "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, | |||
| "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, | |||
| "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, | |||
| } | |||
| return ssl_options | |||
| def init_app(app: DifyApp) -> Celery: | |||
| class FlaskTask(Task): | |||
| def __call__(self, *args: object, **kwargs: object) -> object: | |||
| @@ -33,14 +69,6 @@ def init_app(app: DifyApp) -> Celery: | |||
| task_ignore_result=True, | |||
| ) | |||
| # Add SSL options to the Celery configuration | |||
| ssl_options = { | |||
| "ssl_cert_reqs": None, | |||
| "ssl_ca_certs": None, | |||
| "ssl_certfile": None, | |||
| "ssl_keyfile": None, | |||
| } | |||
| celery_app.conf.update( | |||
| result_backend=dify_config.CELERY_RESULT_BACKEND, | |||
| broker_transport_options=broker_transport_options, | |||
| @@ -51,9 +79,13 @@ def init_app(app: DifyApp) -> Celery: | |||
| timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), | |||
| ) | |||
| if dify_config.BROKER_USE_SSL: | |||
| # Apply SSL configuration if enabled | |||
| ssl_options = _get_celery_ssl_options() | |||
| if ssl_options: | |||
| celery_app.conf.update( | |||
| broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration | |||
| broker_use_ssl=ssl_options, | |||
| # Also apply SSL to the backend if it's Redis | |||
| redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None, | |||
| ) | |||
| if dify_config.LOG_FILE: | |||
| @@ -1,5 +1,6 @@ | |||
| import functools | |||
| import logging | |||
| import ssl | |||
| from collections.abc import Callable | |||
| from datetime import timedelta | |||
| from typing import TYPE_CHECKING, Any, Union | |||
| @@ -116,76 +117,132 @@ class RedisClientWrapper: | |||
| redis_client: RedisClientWrapper = RedisClientWrapper() | |||
| def init_app(app: DifyApp): | |||
| global redis_client | |||
| connection_class: type[Union[Connection, SSLConnection]] = Connection | |||
| if dify_config.REDIS_USE_SSL: | |||
| connection_class = SSLConnection | |||
| def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: | |||
| """Get SSL configuration for Redis connection.""" | |||
| if not dify_config.REDIS_USE_SSL: | |||
| return Connection, {} | |||
| cert_reqs_map = { | |||
| "CERT_NONE": ssl.CERT_NONE, | |||
| "CERT_OPTIONAL": ssl.CERT_OPTIONAL, | |||
| "CERT_REQUIRED": ssl.CERT_REQUIRED, | |||
| } | |||
| ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) | |||
| ssl_kwargs = { | |||
| "ssl_cert_reqs": ssl_cert_reqs, | |||
| "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, | |||
| "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, | |||
| "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, | |||
| } | |||
| return SSLConnection, ssl_kwargs | |||
| def _get_cache_configuration() -> CacheConfig | None: | |||
| """Get client-side cache configuration if enabled.""" | |||
| if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: | |||
| return None | |||
| resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL | |||
| if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: | |||
| if resp_protocol >= 3: | |||
| clientside_cache_config = CacheConfig() | |||
| else: | |||
| raise ValueError("Client side cache is only supported in RESP3") | |||
| else: | |||
| clientside_cache_config = None | |||
| if resp_protocol < 3: | |||
| raise ValueError("Client side cache is only supported in RESP3") | |||
| return CacheConfig() | |||
| redis_params: dict[str, Any] = { | |||
| def _get_base_redis_params() -> dict[str, Any]: | |||
| """Get base Redis connection parameters.""" | |||
| return { | |||
| "username": dify_config.REDIS_USERNAME, | |||
| "password": dify_config.REDIS_PASSWORD or None, # Temporary fix for empty password | |||
| "password": dify_config.REDIS_PASSWORD or None, | |||
| "db": dify_config.REDIS_DB, | |||
| "encoding": "utf-8", | |||
| "encoding_errors": "strict", | |||
| "decode_responses": False, | |||
| "protocol": resp_protocol, | |||
| "cache_config": clientside_cache_config, | |||
| "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, | |||
| "cache_config": _get_cache_configuration(), | |||
| } | |||
| def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: | |||
| """Create Redis client using Sentinel configuration.""" | |||
| if not dify_config.REDIS_SENTINELS: | |||
| raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") | |||
| if not dify_config.REDIS_SENTINEL_SERVICE_NAME: | |||
| raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True") | |||
| sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")] | |||
| sentinel = Sentinel( | |||
| sentinel_hosts, | |||
| sentinel_kwargs={ | |||
| "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, | |||
| "username": dify_config.REDIS_SENTINEL_USERNAME, | |||
| "password": dify_config.REDIS_SENTINEL_PASSWORD, | |||
| }, | |||
| ) | |||
| master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) | |||
| return master | |||
| def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: | |||
| """Create Redis cluster client.""" | |||
| if not dify_config.REDIS_CLUSTERS: | |||
| raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True") | |||
| nodes = [ | |||
| ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) | |||
| for node in dify_config.REDIS_CLUSTERS.split(",") | |||
| ] | |||
| cluster: RedisCluster = RedisCluster( | |||
| startup_nodes=nodes, | |||
| password=dify_config.REDIS_CLUSTERS_PASSWORD, | |||
| protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, | |||
| cache_config=_get_cache_configuration(), | |||
| ) | |||
| return cluster | |||
| def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: | |||
| """Create standalone Redis client.""" | |||
| connection_class, ssl_kwargs = _get_ssl_configuration() | |||
| redis_params.update( | |||
| { | |||
| "host": dify_config.REDIS_HOST, | |||
| "port": dify_config.REDIS_PORT, | |||
| "connection_class": connection_class, | |||
| } | |||
| ) | |||
| if ssl_kwargs: | |||
| redis_params.update(ssl_kwargs) | |||
| pool = redis.ConnectionPool(**redis_params) | |||
| client: redis.Redis = redis.Redis(connection_pool=pool) | |||
| return client | |||
| def init_app(app: DifyApp): | |||
| """Initialize Redis client and attach it to the app.""" | |||
| global redis_client | |||
| # Determine Redis mode and create appropriate client | |||
| if dify_config.REDIS_USE_SENTINEL: | |||
| assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True" | |||
| assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, ( | |||
| "REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True" | |||
| ) | |||
| sentinel_hosts = [ | |||
| (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") | |||
| ] | |||
| sentinel = Sentinel( | |||
| sentinel_hosts, | |||
| sentinel_kwargs={ | |||
| "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, | |||
| "username": dify_config.REDIS_SENTINEL_USERNAME, | |||
| "password": dify_config.REDIS_SENTINEL_PASSWORD, | |||
| }, | |||
| ) | |||
| master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) | |||
| redis_client.initialize(master) | |||
| redis_params = _get_base_redis_params() | |||
| client = _create_sentinel_client(redis_params) | |||
| elif dify_config.REDIS_USE_CLUSTERS: | |||
| assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True" | |||
| nodes = [ | |||
| ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) | |||
| for node in dify_config.REDIS_CLUSTERS.split(",") | |||
| ] | |||
| redis_client.initialize( | |||
| RedisCluster( | |||
| startup_nodes=nodes, | |||
| password=dify_config.REDIS_CLUSTERS_PASSWORD, | |||
| protocol=resp_protocol, | |||
| cache_config=clientside_cache_config, | |||
| ) | |||
| ) | |||
| client = _create_cluster_client() | |||
| else: | |||
| redis_params.update( | |||
| { | |||
| "host": dify_config.REDIS_HOST, | |||
| "port": dify_config.REDIS_PORT, | |||
| "connection_class": connection_class, | |||
| "protocol": resp_protocol, | |||
| "cache_config": clientside_cache_config, | |||
| } | |||
| ) | |||
| pool = redis.ConnectionPool(**redis_params) | |||
| redis_client.initialize(redis.Redis(connection_pool=pool)) | |||
| redis_params = _get_base_redis_params() | |||
| client = _create_standalone_client(redis_params) | |||
| # Initialize the wrapper and attach to app | |||
| redis_client.initialize(client) | |||
| app.extensions["redis"] = redis_client | |||
| @@ -0,0 +1,149 @@ | |||
| """Tests for Celery SSL configuration.""" | |||
| import ssl | |||
| from unittest.mock import MagicMock, patch | |||
| class TestCelerySSLConfiguration: | |||
| """Test suite for Celery SSL configuration.""" | |||
| def test_get_celery_ssl_options_when_ssl_disabled(self): | |||
| """Test SSL options when REDIS_USE_SSL is False.""" | |||
| mock_config = MagicMock() | |||
| mock_config.REDIS_USE_SSL = False | |||
| with patch("extensions.ext_celery.dify_config", mock_config): | |||
| from extensions.ext_celery import _get_celery_ssl_options | |||
| result = _get_celery_ssl_options() | |||
| assert result is None | |||
| def test_get_celery_ssl_options_when_broker_not_redis(self): | |||
| """Test SSL options when broker is not Redis.""" | |||
| mock_config = MagicMock() | |||
| mock_config.REDIS_USE_SSL = True | |||
| mock_config.CELERY_BROKER_URL = "amqp://localhost:5672" | |||
| with patch("extensions.ext_celery.dify_config", mock_config): | |||
| from extensions.ext_celery import _get_celery_ssl_options | |||
| result = _get_celery_ssl_options() | |||
| assert result is None | |||
| def test_get_celery_ssl_options_with_cert_none(self): | |||
| """Test SSL options with CERT_NONE requirement.""" | |||
| mock_config = MagicMock() | |||
| mock_config.REDIS_USE_SSL = True | |||
| mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" | |||
| mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" | |||
| mock_config.REDIS_SSL_CA_CERTS = None | |||
| mock_config.REDIS_SSL_CERTFILE = None | |||
| mock_config.REDIS_SSL_KEYFILE = None | |||
| with patch("extensions.ext_celery.dify_config", mock_config): | |||
| from extensions.ext_celery import _get_celery_ssl_options | |||
| result = _get_celery_ssl_options() | |||
| assert result is not None | |||
| assert result["ssl_cert_reqs"] == ssl.CERT_NONE | |||
| assert result["ssl_ca_certs"] is None | |||
| assert result["ssl_certfile"] is None | |||
| assert result["ssl_keyfile"] is None | |||
| def test_get_celery_ssl_options_with_cert_required(self): | |||
| """Test SSL options with CERT_REQUIRED and certificates.""" | |||
| mock_config = MagicMock() | |||
| mock_config.REDIS_USE_SSL = True | |||
| mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0" | |||
| mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED" | |||
| mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" | |||
| mock_config.REDIS_SSL_CERTFILE = "/path/to/client.crt" | |||
| mock_config.REDIS_SSL_KEYFILE = "/path/to/client.key" | |||
| with patch("extensions.ext_celery.dify_config", mock_config): | |||
| from extensions.ext_celery import _get_celery_ssl_options | |||
| result = _get_celery_ssl_options() | |||
| assert result is not None | |||
| assert result["ssl_cert_reqs"] == ssl.CERT_REQUIRED | |||
| assert result["ssl_ca_certs"] == "/path/to/ca.crt" | |||
| assert result["ssl_certfile"] == "/path/to/client.crt" | |||
| assert result["ssl_keyfile"] == "/path/to/client.key" | |||
| def test_get_celery_ssl_options_with_cert_optional(self): | |||
| """Test SSL options with CERT_OPTIONAL requirement.""" | |||
| mock_config = MagicMock() | |||
| mock_config.REDIS_USE_SSL = True | |||
| mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" | |||
| mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL" | |||
| mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" | |||
| mock_config.REDIS_SSL_CERTFILE = None | |||
| mock_config.REDIS_SSL_KEYFILE = None | |||
| with patch("extensions.ext_celery.dify_config", mock_config): | |||
| from extensions.ext_celery import _get_celery_ssl_options | |||
| result = _get_celery_ssl_options() | |||
| assert result is not None | |||
| assert result["ssl_cert_reqs"] == ssl.CERT_OPTIONAL | |||
| assert result["ssl_ca_certs"] == "/path/to/ca.crt" | |||
| def test_get_celery_ssl_options_with_invalid_cert_reqs(self): | |||
| """Test SSL options with invalid cert requirement defaults to CERT_NONE.""" | |||
| mock_config = MagicMock() | |||
| mock_config.REDIS_USE_SSL = True | |||
| mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" | |||
| mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE" | |||
| mock_config.REDIS_SSL_CA_CERTS = None | |||
| mock_config.REDIS_SSL_CERTFILE = None | |||
| mock_config.REDIS_SSL_KEYFILE = None | |||
| with patch("extensions.ext_celery.dify_config", mock_config): | |||
| from extensions.ext_celery import _get_celery_ssl_options | |||
| result = _get_celery_ssl_options() | |||
| assert result is not None | |||
| assert result["ssl_cert_reqs"] == ssl.CERT_NONE # Should default to CERT_NONE | |||
| def test_celery_init_applies_ssl_to_broker_and_backend(self): | |||
| """Test that SSL options are applied to both broker and backend when using Redis.""" | |||
| mock_config = MagicMock() | |||
| mock_config.REDIS_USE_SSL = True | |||
| mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" | |||
| mock_config.CELERY_BACKEND = "redis" | |||
| mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" | |||
| mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" | |||
| mock_config.REDIS_SSL_CA_CERTS = None | |||
| mock_config.REDIS_SSL_CERTFILE = None | |||
| mock_config.REDIS_SSL_KEYFILE = None | |||
| mock_config.CELERY_USE_SENTINEL = False | |||
| mock_config.LOG_FORMAT = "%(message)s" | |||
| mock_config.LOG_TZ = "UTC" | |||
| mock_config.LOG_FILE = None | |||
| # Mock all the scheduler configs | |||
| mock_config.CELERY_BEAT_SCHEDULER_TIME = 1 | |||
| mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False | |||
| mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False | |||
| mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False | |||
| mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False | |||
| mock_config.ENABLE_CLEAN_MESSAGES = False | |||
| mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False | |||
| mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False | |||
| mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False | |||
| with patch("extensions.ext_celery.dify_config", mock_config): | |||
| from dify_app import DifyApp | |||
| from extensions.ext_celery import init_app | |||
| app = DifyApp(__name__) | |||
| celery_app = init_app(app) | |||
| # Check that SSL options were applied | |||
| assert "broker_use_ssl" in celery_app.conf | |||
| assert celery_app.conf["broker_use_ssl"] is not None | |||
| assert celery_app.conf["broker_use_ssl"]["ssl_cert_reqs"] == ssl.CERT_NONE | |||
| # Check that SSL is also applied to Redis backend | |||
| assert "redis_backend_use_ssl" in celery_app.conf | |||
| assert celery_app.conf["redis_backend_use_ssl"] is not None | |||
| @@ -264,6 +264,15 @@ REDIS_PORT=6379 | |||
| REDIS_USERNAME= | |||
| REDIS_PASSWORD=difyai123456 | |||
| REDIS_USE_SSL=false | |||
| # SSL configuration for Redis (when REDIS_USE_SSL=true) | |||
| REDIS_SSL_CERT_REQS=CERT_NONE | |||
| # Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED | |||
| REDIS_SSL_CA_CERTS= | |||
| # Path to CA certificate file for SSL verification | |||
| REDIS_SSL_CERTFILE= | |||
| # Path to client certificate file for SSL authentication | |||
| REDIS_SSL_KEYFILE= | |||
| # Path to client private key file for SSL authentication | |||
| REDIS_DB=0 | |||
| # Whether to use Redis Sentinel mode. | |||
| @@ -71,6 +71,10 @@ x-shared-env: &shared-api-worker-env | |||
| REDIS_USERNAME: ${REDIS_USERNAME:-} | |||
| REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456} | |||
| REDIS_USE_SSL: ${REDIS_USE_SSL:-false} | |||
| REDIS_SSL_CERT_REQS: ${REDIS_SSL_CERT_REQS:-CERT_NONE} | |||
| REDIS_SSL_CA_CERTS: ${REDIS_SSL_CA_CERTS:-} | |||
| REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} | |||
| REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} | |||
| REDIS_DB: ${REDIS_DB:-0} | |||
| REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} | |||
| REDIS_SENTINELS: ${REDIS_SENTINELS:-} | |||