Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.1
| @@ -11,10 +11,6 @@ if TYPE_CHECKING: | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| tenant_id: ContextVar[str] = ContextVar("tenant_id") | |||
| workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") | |||
| """ | |||
| To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with | |||
| """ | |||
| @@ -3,7 +3,7 @@ from flask_restful import Resource, marshal, marshal_with, reqparse | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.service_api import api | |||
| from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token | |||
| from controllers.service_api.wraps import validate_app_token | |||
| from extensions.ext_redis import redis_client | |||
| from fields.annotation_fields import ( | |||
| annotation_fields, | |||
| @@ -14,7 +14,7 @@ from services.annotation_service import AppAnnotationService | |||
| class AnnotationReplyActionApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @validate_app_token | |||
| def post(self, app_model: App, end_user: EndUser, action): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("score_threshold", required=True, type=float, location="json") | |||
| @@ -31,7 +31,7 @@ class AnnotationReplyActionApi(Resource): | |||
| class AnnotationReplyActionStatusApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @validate_app_token | |||
| def get(self, app_model: App, end_user: EndUser, job_id, action): | |||
| job_id = str(job_id) | |||
| app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) | |||
| @@ -49,7 +49,7 @@ class AnnotationReplyActionStatusApi(Resource): | |||
| class AnnotationListApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @validate_app_token | |||
| def get(self, app_model: App, end_user: EndUser): | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| @@ -65,7 +65,7 @@ class AnnotationListApi(Resource): | |||
| } | |||
| return response, 200 | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @validate_app_token | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| parser = reqparse.RequestParser() | |||
| @@ -77,7 +77,7 @@ class AnnotationListApi(Resource): | |||
| class AnnotationUpdateDeleteApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @validate_app_token | |||
| @marshal_with(annotation_fields) | |||
| def put(self, app_model: App, end_user: EndUser, annotation_id): | |||
| if not current_user.is_editor: | |||
| @@ -91,7 +91,7 @@ class AnnotationUpdateDeleteApi(Resource): | |||
| annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) | |||
| return annotation | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @validate_app_token | |||
| def delete(self, app_model: App, end_user: EndUser, annotation_id): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -99,7 +99,12 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| if user_id: | |||
| user_id = str(user_id) | |||
| kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) | |||
| end_user = create_or_update_end_user_for_user_id(app_model, user_id) | |||
| kwargs["end_user"] = end_user | |||
| # Set EndUser as current logged-in user for flask_login.current_user | |||
| current_app.login_manager._update_request_context_with_user(end_user) # type: ignore | |||
| user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore | |||
| return view_func(*args, **kwargs) | |||
| @@ -5,7 +5,7 @@ import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| from flask import Flask, copy_current_request_context, current_app, has_request_context | |||
| from pydantic import ValidationError | |||
| from sqlalchemy.orm import sessionmaker | |||
| @@ -158,7 +158,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| trace_manager=trace_manager, | |||
| workflow_run_id=workflow_run_id, | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -240,7 +239,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| node_id=node_id, inputs=args["inputs"] | |||
| ), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -316,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| extras={"auto_generate_conversation_name": False}, | |||
| single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -399,18 +396,23 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| message_id=message.id, | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "conversation_id": conversation.id, | |||
| "message_id": message.id, | |||
| "context": contextvars.copy_context(), | |||
| }, | |||
| ) | |||
| # new thread with request context and contextvars | |||
| context = contextvars.copy_context() | |||
| @copy_current_request_context | |||
| def worker_with_context(): | |||
| # Run the worker within the copied context | |||
| return context.run( | |||
| self._generate_worker, | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation_id=conversation.id, | |||
| message_id=message.id, | |||
| context=context, | |||
| ) | |||
| worker_thread = threading.Thread(target=worker_with_context) | |||
| worker_thread.start() | |||
| @@ -449,8 +451,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| """ | |||
| for var, val in context.items(): | |||
| var.set(val) | |||
| # Save current user before entering new app context | |||
| from flask import g | |||
| saved_user = None | |||
| if has_request_context() and hasattr(g, "_login_user"): | |||
| saved_user = g._login_user | |||
| with flask_app.app_context(): | |||
| try: | |||
| # Restore user in new app context | |||
| if saved_user is not None: | |||
| from flask import g | |||
| g._login_user = saved_user | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| @@ -5,7 +5,7 @@ import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Literal, Union, overload | |||
| from flask import Flask, current_app | |||
| from flask import Flask, copy_current_request_context, current_app, has_request_context | |||
| from pydantic import ValidationError | |||
| from configs import dify_config | |||
| @@ -179,18 +179,23 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| message_id=message.id, | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "context": contextvars.copy_context(), | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "conversation_id": conversation.id, | |||
| "message_id": message.id, | |||
| }, | |||
| ) | |||
| # new thread with request context and contextvars | |||
| context = contextvars.copy_context() | |||
| @copy_current_request_context | |||
| def worker_with_context(): | |||
| # Run the worker within the copied context | |||
| return context.run( | |||
| self._generate_worker, | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| context=context, | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation_id=conversation.id, | |||
| message_id=message.id, | |||
| ) | |||
| worker_thread = threading.Thread(target=worker_with_context) | |||
| worker_thread.start() | |||
| @@ -227,8 +232,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| for var, val in context.items(): | |||
| var.set(val) | |||
| # Save current user before entering new app context | |||
| from flask import g | |||
| saved_user = None | |||
| if has_request_context() and hasattr(g, "_login_user"): | |||
| saved_user = g._login_user | |||
| with flask_app.app_context(): | |||
| try: | |||
| # Restore user in new app context | |||
| if saved_user is not None: | |||
| from flask import g | |||
| g._login_user = saved_user | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| @@ -4,7 +4,7 @@ import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Literal, Union, overload | |||
| from flask import Flask, current_app | |||
| from flask import Flask, copy_current_request_context, current_app | |||
| from pydantic import ValidationError | |||
| from configs import dify_config | |||
| @@ -170,17 +170,18 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| message_id=message.id, | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "conversation_id": conversation.id, | |||
| "message_id": message.id, | |||
| }, | |||
| ) | |||
| # new thread with request context | |||
| @copy_current_request_context | |||
| def worker_with_context(): | |||
| return self._generate_worker( | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation_id=conversation.id, | |||
| message_id=message.id, | |||
| ) | |||
| worker_thread = threading.Thread(target=worker_with_context) | |||
| worker_thread.start() | |||
| @@ -4,7 +4,7 @@ import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Literal, Union, overload | |||
| from flask import Flask, current_app | |||
| from flask import Flask, copy_current_request_context, current_app | |||
| from pydantic import ValidationError | |||
| from configs import dify_config | |||
| @@ -151,16 +151,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| message_id=message.id, | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "message_id": message.id, | |||
| }, | |||
| ) | |||
| # new thread with request context | |||
| @copy_current_request_context | |||
| def worker_with_context(): | |||
| return self._generate_worker( | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| message_id=message.id, | |||
| ) | |||
| worker_thread = threading.Thread(target=worker_with_context) | |||
| worker_thread.start() | |||
| @@ -313,16 +314,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| message_id=message.id, | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "message_id": message.id, | |||
| }, | |||
| ) | |||
| # new thread with request context | |||
| @copy_current_request_context | |||
| def worker_with_context(): | |||
| return self._generate_worker( | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| message_id=message.id, | |||
| ) | |||
| worker_thread = threading.Thread(target=worker_with_context) | |||
| worker_thread.start() | |||
| @@ -5,7 +5,7 @@ import uuid | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| from flask import Flask, copy_current_request_context, current_app, has_request_context | |||
| from pydantic import ValidationError | |||
| from sqlalchemy.orm import sessionmaker | |||
| @@ -135,7 +135,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow_run_id=workflow_run_id, | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -207,17 +206,22 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| app_mode=app_model.mode, | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread( | |||
| target=self._generate_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "application_generate_entity": application_generate_entity, | |||
| "queue_manager": queue_manager, | |||
| "context": contextvars.copy_context(), | |||
| "workflow_thread_pool_id": workflow_thread_pool_id, | |||
| }, | |||
| ) | |||
| # new thread with request context and contextvars | |||
| context = contextvars.copy_context() | |||
| @copy_current_request_context | |||
| def worker_with_context(): | |||
| # Run the worker within the copied context | |||
| return context.run( | |||
| self._generate_worker, | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| context=context, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| ) | |||
| worker_thread = threading.Thread(target=worker_with_context) | |||
| worker_thread.start() | |||
| @@ -277,7 +281,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| ), | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -354,7 +357,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| @@ -408,8 +410,22 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| """ | |||
| for var, val in context.items(): | |||
| var.set(val) | |||
| # Save current user before entering new app context | |||
| from flask import g | |||
| saved_user = None | |||
| if has_request_context() and hasattr(g, "_login_user"): | |||
| saved_user = g._login_user | |||
| with flask_app.app_context(): | |||
| try: | |||
| # Restore user in new app context | |||
| if saved_user is not None: | |||
| from flask import g | |||
| g._login_user = saved_user | |||
| # workflow app | |||
| runner = WorkflowAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| @@ -5,7 +5,6 @@ from flask import Response, request | |||
| from flask_login import user_loaded_from_request, user_logged_in | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| import contexts | |||
| from configs import dify_config | |||
| from dify_app import DifyApp | |||
| from extensions.ext_database import db | |||
| @@ -82,8 +81,8 @@ def on_user_logged_in(_sender, user): | |||
| Note: AccountService.load_logged_in_account will populate user.current_tenant_id | |||
| through the load_user method, which calls account.set_tenant_id(). | |||
| """ | |||
| if user and isinstance(user, Account) and user.current_tenant_id: | |||
| contexts.tenant_id.set(user.current_tenant_id) | |||
| # tenant_id context variable removed - using current_user.current_tenant_id directly | |||
| pass | |||
| @login_manager.unauthorized_handler | |||
| @@ -6,6 +6,8 @@ from enum import Enum, StrEnum | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from uuid import uuid4 | |||
| from flask_login import current_user | |||
| from core.variables import utils as variable_utils | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from factories.variable_factory import build_segment | |||
| @@ -17,7 +19,6 @@ import sqlalchemy as sa | |||
| from sqlalchemy import UniqueConstraint, func | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| import contexts | |||
| from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE | |||
| from core.helper import encrypter | |||
| from core.variables import SecretVariable, Segment, SegmentType, Variable | |||
| @@ -274,7 +275,16 @@ class Workflow(Base): | |||
| if self._environment_variables is None: | |||
| self._environment_variables = "{}" | |||
| tenant_id = contexts.tenant_id.get() | |||
| # Get tenant_id from current_user (Account or EndUser) | |||
| if isinstance(current_user, Account): | |||
| # Account user | |||
| tenant_id = current_user.current_tenant_id | |||
| else: | |||
| # EndUser | |||
| tenant_id = current_user.tenant_id | |||
| if not tenant_id: | |||
| return [] | |||
| environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) | |||
| results = [ | |||
| @@ -297,7 +307,17 @@ class Workflow(Base): | |||
| self._environment_variables = "{}" | |||
| return | |||
| tenant_id = contexts.tenant_id.get() | |||
| # Get tenant_id from current_user (Account or EndUser) | |||
| if isinstance(current_user, Account): | |||
| # Account user | |||
| tenant_id = current_user.current_tenant_id | |||
| else: | |||
| # EndUser | |||
| tenant_id = current_user.tenant_id | |||
| if not tenant_id: | |||
| self._environment_variables = "{}" | |||
| return | |||
| value = list(value) | |||
| if any(var for var in value if not var.id): | |||
| @@ -2,14 +2,13 @@ import json | |||
| from unittest import mock | |||
| from uuid import uuid4 | |||
| import contexts | |||
| from constants import HIDDEN_VALUE | |||
| from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable | |||
| from models.workflow import Workflow, WorkflowNodeExecution | |||
| def test_environment_variables(): | |||
| contexts.tenant_id.set("tenant_id") | |||
| # tenant_id context variable removed - using current_user.current_tenant_id directly | |||
| # Create a Workflow instance | |||
| workflow = Workflow( | |||
| @@ -38,9 +37,14 @@ def test_environment_variables(): | |||
| {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} | |||
| ) | |||
| # Mock current_user as an EndUser | |||
| mock_user = mock.Mock() | |||
| mock_user.tenant_id = "tenant_id" | |||
| with ( | |||
| mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), | |||
| mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), | |||
| mock.patch("models.workflow.current_user", mock_user), | |||
| ): | |||
| # Set the environment_variables property of the Workflow instance | |||
| variables = [variable1, variable2, variable3, variable4] | |||
| @@ -51,7 +55,7 @@ def test_environment_variables(): | |||
| def test_update_environment_variables(): | |||
| contexts.tenant_id.set("tenant_id") | |||
| # tenant_id context variable removed - using current_user.current_tenant_id directly | |||
| # Create a Workflow instance | |||
| workflow = Workflow( | |||
| @@ -80,9 +84,14 @@ def test_update_environment_variables(): | |||
| {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} | |||
| ) | |||
| # Mock current_user as an EndUser | |||
| mock_user = mock.Mock() | |||
| mock_user.tenant_id = "tenant_id" | |||
| with ( | |||
| mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), | |||
| mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), | |||
| mock.patch("models.workflow.current_user", mock_user), | |||
| ): | |||
| variables = [variable1, variable2, variable3, variable4] | |||
| @@ -104,7 +113,7 @@ def test_update_environment_variables(): | |||
| def test_to_dict(): | |||
| contexts.tenant_id.set("tenant_id") | |||
| # tenant_id context variable removed - using current_user.current_tenant_id directly | |||
| # Create a Workflow instance | |||
| workflow = Workflow( | |||
| @@ -121,9 +130,14 @@ def test_to_dict(): | |||
| # Create some EnvironmentVariable instances | |||
| # Mock current_user as an EndUser | |||
| mock_user = mock.Mock() | |||
| mock_user.tenant_id = "tenant_id" | |||
| with ( | |||
| mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), | |||
| mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), | |||
| mock.patch("models.workflow.current_user", mock_user), | |||
| ): | |||
| # Set the environment_variables property of the Workflow instance | |||
| workflow.environment_variables = [ | |||