Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.5.0
| # Next.js build output | # Next.js build output | ||||
| .next/ | .next/ | ||||
| # AI Assistant | |||||
| .roo/ |
| from collections.abc import Generator, Mapping | from collections.abc import Generator, Mapping | ||||
| from typing import Any, Literal, Optional, Union, overload | from typing import Any, Literal, Optional, Union, overload | ||||
| from flask import Flask, copy_current_request_context, current_app, has_request_context | |||||
| from flask import Flask, current_app | |||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||
| from sqlalchemy.orm import sessionmaker | from sqlalchemy.orm import sessionmaker | ||||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from libs.flask_utils import preserve_flask_contexts | |||||
| from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom | from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom | ||||
| from models.enums import WorkflowRunTriggeredFrom | from models.enums import WorkflowRunTriggeredFrom | ||||
| from services.conversation_service import ConversationService | from services.conversation_service import ConversationService | ||||
| # new thread with request context and contextvars | # new thread with request context and contextvars | ||||
| context = contextvars.copy_context() | context = contextvars.copy_context() | ||||
| @copy_current_request_context | |||||
| def worker_with_context(): | |||||
| # Run the worker within the copied context | |||||
| return context.run( | |||||
| self._generate_worker, | |||||
| flask_app=current_app._get_current_object(), # type: ignore | |||||
| application_generate_entity=application_generate_entity, | |||||
| queue_manager=queue_manager, | |||||
| conversation_id=conversation.id, | |||||
| message_id=message.id, | |||||
| context=context, | |||||
| ) | |||||
| worker_thread = threading.Thread(target=worker_with_context) | |||||
| worker_thread = threading.Thread( | |||||
| target=self._generate_worker, | |||||
| kwargs={ | |||||
| "flask_app": current_app._get_current_object(), # type: ignore | |||||
| "application_generate_entity": application_generate_entity, | |||||
| "queue_manager": queue_manager, | |||||
| "conversation_id": conversation.id, | |||||
| "message_id": message.id, | |||||
| "context": context, | |||||
| }, | |||||
| ) | |||||
| worker_thread.start() | worker_thread.start() | ||||
| :param message_id: message ID | :param message_id: message ID | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| for var, val in context.items(): | |||||
| var.set(val) | |||||
| # FIXME(-LAN-): Save current user before entering new app context | |||||
| from flask import g | |||||
| saved_user = None | |||||
| if has_request_context() and hasattr(g, "_login_user"): | |||||
| saved_user = g._login_user | |||||
| with flask_app.app_context(): | |||||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||||
| try: | try: | ||||
| # Restore user in new app context | |||||
| if saved_user is not None: | |||||
| from flask import g | |||||
| g._login_user = saved_user | |||||
| # get conversation and message | # get conversation and message | ||||
| conversation = self._get_conversation(conversation_id) | conversation = self._get_conversation(conversation_id) | ||||
| message = self._get_message(message_id) | message = self._get_message(message_id) |
| from collections.abc import Generator, Mapping | from collections.abc import Generator, Mapping | ||||
| from typing import Any, Literal, Union, overload | from typing import Any, Literal, Union, overload | ||||
| from flask import Flask, copy_current_request_context, current_app, has_request_context | |||||
| from flask import Flask, current_app | |||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from libs.flask_utils import preserve_flask_contexts | |||||
| from models import Account, App, EndUser | from models import Account, App, EndUser | ||||
| from services.conversation_service import ConversationService | from services.conversation_service import ConversationService | ||||
| from services.errors.message import MessageNotExistsError | from services.errors.message import MessageNotExistsError | ||||
| # new thread with request context and contextvars | # new thread with request context and contextvars | ||||
| context = contextvars.copy_context() | context = contextvars.copy_context() | ||||
| @copy_current_request_context | |||||
| def worker_with_context(): | |||||
| # Run the worker within the copied context | |||||
| return context.run( | |||||
| self._generate_worker, | |||||
| flask_app=current_app._get_current_object(), # type: ignore | |||||
| context=context, | |||||
| application_generate_entity=application_generate_entity, | |||||
| queue_manager=queue_manager, | |||||
| conversation_id=conversation.id, | |||||
| message_id=message.id, | |||||
| ) | |||||
| worker_thread = threading.Thread(target=worker_with_context) | |||||
| worker_thread = threading.Thread( | |||||
| target=self._generate_worker, | |||||
| kwargs={ | |||||
| "flask_app": current_app._get_current_object(), # type: ignore | |||||
| "context": context, | |||||
| "application_generate_entity": application_generate_entity, | |||||
| "queue_manager": queue_manager, | |||||
| "conversation_id": conversation.id, | |||||
| "message_id": message.id, | |||||
| }, | |||||
| ) | |||||
| worker_thread.start() | worker_thread.start() | ||||
| :param message_id: message ID | :param message_id: message ID | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| for var, val in context.items(): | |||||
| var.set(val) | |||||
| # FIXME(-LAN-): Save current user before entering new app context | |||||
| from flask import g | |||||
| saved_user = None | |||||
| if has_request_context() and hasattr(g, "_login_user"): | |||||
| saved_user = g._login_user | |||||
| with flask_app.app_context(): | |||||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||||
| try: | try: | ||||
| # Restore user in new app context | |||||
| if saved_user is not None: | |||||
| from flask import g | |||||
| g._login_user = saved_user | |||||
| # get conversation and message | # get conversation and message | ||||
| conversation = self._get_conversation(conversation_id) | conversation = self._get_conversation(conversation_id) | ||||
| message = self._get_message(message_id) | message = self._get_message(message_id) |
| from collections.abc import Generator, Mapping, Sequence | from collections.abc import Generator, Mapping, Sequence | ||||
| from typing import Any, Literal, Optional, Union, overload | from typing import Any, Literal, Optional, Union, overload | ||||
| from flask import Flask, copy_current_request_context, current_app, has_request_context | |||||
| from flask import Flask, current_app | |||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||
| from sqlalchemy.orm import sessionmaker | from sqlalchemy.orm import sessionmaker | ||||
| from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from libs.flask_utils import preserve_flask_contexts | |||||
| from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom | ||||
| from models.enums import WorkflowRunTriggeredFrom | from models.enums import WorkflowRunTriggeredFrom | ||||
| # new thread with request context and contextvars | # new thread with request context and contextvars | ||||
| context = contextvars.copy_context() | context = contextvars.copy_context() | ||||
| @copy_current_request_context | |||||
| def worker_with_context(): | |||||
| # Run the worker within the copied context | |||||
| return context.run( | |||||
| self._generate_worker, | |||||
| flask_app=current_app._get_current_object(), # type: ignore | |||||
| application_generate_entity=application_generate_entity, | |||||
| queue_manager=queue_manager, | |||||
| context=context, | |||||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||||
| ) | |||||
| worker_thread = threading.Thread(target=worker_with_context) | |||||
| worker_thread = threading.Thread( | |||||
| target=self._generate_worker, | |||||
| kwargs={ | |||||
| "flask_app": current_app._get_current_object(), # type: ignore | |||||
| "application_generate_entity": application_generate_entity, | |||||
| "queue_manager": queue_manager, | |||||
| "context": context, | |||||
| "workflow_thread_pool_id": workflow_thread_pool_id, | |||||
| }, | |||||
| ) | |||||
| worker_thread.start() | worker_thread.start() | ||||
| :param workflow_thread_pool_id: workflow thread pool id | :param workflow_thread_pool_id: workflow thread pool id | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| for var, val in context.items(): | |||||
| var.set(val) | |||||
| # FIXME(-LAN-): Save current user before entering new app context | |||||
| from flask import g | |||||
| saved_user = None | |||||
| if has_request_context() and hasattr(g, "_login_user"): | |||||
| saved_user = g._login_user | |||||
| with flask_app.app_context(): | |||||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||||
| try: | try: | ||||
| # Restore user in new app context | |||||
| if saved_user is not None: | |||||
| from flask import g | |||||
| g._login_user = saved_user | |||||
| # workflow app | # workflow app | ||||
| runner = WorkflowAppRunner( | runner = WorkflowAppRunner( | ||||
| application_generate_entity=application_generate_entity, | application_generate_entity=application_generate_entity, |
| from datetime import UTC, datetime | from datetime import UTC, datetime | ||||
| from typing import Any, Optional, cast | from typing import Any, Optional, cast | ||||
| from flask import Flask, current_app, has_request_context | |||||
| from flask import Flask, current_app | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError | from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError | ||||
| from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle | from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle | ||||
| from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | ||||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | ||||
| from libs.flask_utils import preserve_flask_contexts | |||||
| from models.enums import UserFrom | from models.enums import UserFrom | ||||
| from models.workflow import WorkflowType | from models.workflow import WorkflowType | ||||
| """ | """ | ||||
| Run parallel nodes | Run parallel nodes | ||||
| """ | """ | ||||
| for var, val in context.items(): | |||||
| var.set(val) | |||||
| # FIXME(-LAN-): Save current user before entering new app context | |||||
| from flask import g | |||||
| saved_user = None | |||||
| if has_request_context() and hasattr(g, "_login_user"): | |||||
| saved_user = g._login_user | |||||
| with flask_app.app_context(): | |||||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||||
| try: | try: | ||||
| # Restore user in new app context | |||||
| if saved_user is not None: | |||||
| from flask import g | |||||
| g._login_user = saved_user | |||||
| q.put( | q.put( | ||||
| ParallelBranchRunStartedEvent( | ParallelBranchRunStartedEvent( | ||||
| parallel_id=parallel_id, | parallel_id=parallel_id, |
| from queue import Empty, Queue | from queue import Empty, Queue | ||||
| from typing import TYPE_CHECKING, Any, Optional, cast | from typing import TYPE_CHECKING, Any, Optional, cast | ||||
| from flask import Flask, current_app, has_request_context | |||||
| from flask import Flask, current_app | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.variables import ArrayVariable, IntegerVariable, NoneVariable | from core.variables import ArrayVariable, IntegerVariable, NoneVariable | ||||
| from core.workflow.nodes.enums import NodeType | from core.workflow.nodes.enums import NodeType | ||||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | ||||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | ||||
| from libs.flask_utils import preserve_flask_contexts | |||||
| from .exc import ( | from .exc import ( | ||||
| InvalidIteratorValueError, | InvalidIteratorValueError, | ||||
| """ | """ | ||||
| run single iteration in parallel mode | run single iteration in parallel mode | ||||
| """ | """ | ||||
| for var, val in context.items(): | |||||
| var.set(val) | |||||
| # FIXME(-LAN-): Save current user before entering new app context | |||||
| from flask import g | |||||
| saved_user = None | |||||
| if has_request_context() and hasattr(g, "_login_user"): | |||||
| saved_user = g._login_user | |||||
| with flask_app.app_context(): | |||||
| # Restore user in new app context | |||||
| if saved_user is not None: | |||||
| from flask import g | |||||
| g._login_user = saved_user | |||||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||||
| parallel_mode_run_id = uuid.uuid4().hex | parallel_mode_run_id = uuid.uuid4().hex | ||||
| graph_engine_copy = graph_engine.create_copy() | graph_engine_copy = graph_engine.create_copy() | ||||
| variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool | variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool |
| import contextvars | |||||
| from collections.abc import Iterator | |||||
| from contextlib import contextmanager | |||||
| from typing import TypeVar | |||||
| from flask import Flask, g, has_request_context | |||||
| T = TypeVar("T") | |||||
| @contextmanager | |||||
| def preserve_flask_contexts( | |||||
| flask_app: Flask, | |||||
| context_vars: contextvars.Context, | |||||
| ) -> Iterator[None]: | |||||
| """ | |||||
| A context manager that handles: | |||||
| 1. flask-login's UserProxy copy | |||||
| 2. ContextVars copy | |||||
| 3. flask_app.app_context() | |||||
| This context manager ensures that the Flask application context is properly set up, | |||||
| the current user is preserved across context boundaries, and any provided context variables | |||||
| are set within the new context. | |||||
| Note: | |||||
| This manager aims to allow use current_user cross thread and app context, | |||||
| but it's not the recommend use, it's better to pass user directly in parameters. | |||||
| Args: | |||||
| flask_app: The Flask application instance | |||||
| context_vars: contextvars.Context object containing context variables to be set in the new context | |||||
| Yields: | |||||
| None | |||||
| Example: | |||||
| ```python | |||||
| with preserve_flask_contexts(flask_app, context_vars=context_vars): | |||||
| # Code that needs Flask app context and context variables | |||||
| # Current user will be preserved if available | |||||
| ``` | |||||
| """ | |||||
| # Set context variables if provided | |||||
| if context_vars: | |||||
| for var, val in context_vars.items(): | |||||
| var.set(val) | |||||
| # Save current user before entering new app context | |||||
| saved_user = None | |||||
| if has_request_context() and hasattr(g, "_login_user"): | |||||
| saved_user = g._login_user | |||||
| # Enter Flask app context | |||||
| with flask_app.app_context(): | |||||
| try: | |||||
| # Restore user in new app context if it was saved | |||||
| if saved_user is not None: | |||||
| g._login_user = saved_user | |||||
| # Yield control back to the caller | |||||
| yield | |||||
| finally: | |||||
| # Any cleanup can be added here if needed | |||||
| pass |
| import contextvars | |||||
| import threading | |||||
| from typing import Optional | |||||
| import pytest | |||||
| from flask import Flask | |||||
| from flask_login import LoginManager, UserMixin, current_user, login_user | |||||
| from libs.flask_utils import preserve_flask_contexts | |||||
| class User(UserMixin): | |||||
| """Simple User class for testing.""" | |||||
| def __init__(self, id: str): | |||||
| self.id = id | |||||
| def get_id(self) -> str: | |||||
| return self.id | |||||
| @pytest.fixture | |||||
| def login_app(app: Flask) -> Flask: | |||||
| """Set up a Flask app with flask-login.""" | |||||
| # Set a secret key for the app | |||||
| app.config["SECRET_KEY"] = "test-secret-key" | |||||
| login_manager = LoginManager() | |||||
| login_manager.init_app(app) | |||||
| @login_manager.user_loader | |||||
| def load_user(user_id: str) -> Optional[User]: | |||||
| if user_id == "test_user": | |||||
| return User("test_user") | |||||
| return None | |||||
| return app | |||||
| @pytest.fixture | |||||
| def test_user() -> User: | |||||
| """Create a test user.""" | |||||
| return User("test_user") | |||||
| def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: User): | |||||
| """ | |||||
| Test that current_user is not accessible in a different thread without preserve_flask_contexts. | |||||
| This test demonstrates that without the preserve_flask_contexts, we cannot access | |||||
| current_user in a different thread, even with app_context. | |||||
| """ | |||||
| # Log in the user in the main thread | |||||
| with login_app.test_request_context(): | |||||
| login_user(test_user) | |||||
| assert current_user.is_authenticated | |||||
| assert current_user.id == "test_user" | |||||
| # Store the result of the thread execution | |||||
| result = {"user_accessible": True, "error": None} | |||||
| # Define a function to run in a separate thread | |||||
| def check_user_in_thread(): | |||||
| try: | |||||
| # Try to access current_user in a different thread with app_context | |||||
| with login_app.app_context(): | |||||
| # This should fail because current_user is not accessible across threads | |||||
| # without preserve_flask_contexts | |||||
| result["user_accessible"] = current_user.is_authenticated | |||||
| except Exception as e: | |||||
| result["error"] = str(e) # type: ignore | |||||
| # Run the function in a separate thread | |||||
| thread = threading.Thread(target=check_user_in_thread) | |||||
| thread.start() | |||||
| thread.join() | |||||
| # Verify that we got an error or current_user is not authenticated | |||||
| assert result["error"] is not None or (result["user_accessible"] is not None and not result["user_accessible"]) | |||||
| def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, test_user: User): | |||||
| """ | |||||
| Test that current_user is accessible in a different thread with preserve_flask_contexts. | |||||
| This test demonstrates that with the preserve_flask_contexts, we can access | |||||
| current_user in a different thread. | |||||
| """ | |||||
| # Log in the user in the main thread | |||||
| with login_app.test_request_context(): | |||||
| login_user(test_user) | |||||
| assert current_user.is_authenticated | |||||
| assert current_user.id == "test_user" | |||||
| # Save the context variables | |||||
| context_vars = contextvars.copy_context() | |||||
| # Store the result of the thread execution | |||||
| result = {"user_accessible": False, "user_id": None, "error": None} | |||||
| # Define a function to run in a separate thread | |||||
| def check_user_in_thread_with_manager(): | |||||
| try: | |||||
| # Use preserve_flask_contexts to access current_user in a different thread | |||||
| with preserve_flask_contexts(login_app, context_vars): | |||||
| from flask_login import current_user | |||||
| if current_user: | |||||
| result["user_accessible"] = True | |||||
| result["user_id"] = current_user.id | |||||
| else: | |||||
| result["user_accessible"] = False | |||||
| except Exception as e: | |||||
| result["error"] = str(e) # type: ignore | |||||
| # Run the function in a separate thread | |||||
| thread = threading.Thread(target=check_user_in_thread_with_manager) | |||||
| thread.start() | |||||
| thread.join() | |||||
| # Verify that current_user is accessible and has the correct ID | |||||
| assert result["error"] is None | |||||
| assert result["user_accessible"] is True | |||||
| assert result["user_id"] == "test_user" |