| @@ -1,3 +1,7 @@ | |||
| from contextvars import ContextVar | |||
| tenant_id: ContextVar[str] = ContextVar('tenant_id') | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| tenant_id: ContextVar[str] = ContextVar('tenant_id') | |||
| workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool') | |||
| @@ -8,6 +8,8 @@ from typing import Union | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| import contexts | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| @@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| InvokeFrom, | |||
| ) | |||
| from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import App, Conversation, EndUser, Message | |||
| from models.workflow import Workflow | |||
| from models.workflow import ConversationVariable, Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| stream=stream | |||
| ) | |||
| def single_iteration_generate(self, app_model: App, | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| @@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| """ | |||
| if not node_id: | |||
| raise ValueError('node_id is required') | |||
| if args.get('inputs') is None: | |||
| raise ValueError('inputs is required') | |||
| extras = { | |||
| "auto_generate_conversation_name": False | |||
| } | |||
| @@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| # update conversation features | |||
| conversation.override_model_configs = workflow.features | |||
| db.session.commit() | |||
| db.session.refresh(conversation) | |||
| # db.session.refresh(conversation) | |||
| # init queue manager | |||
| queue_manager = MessageBasedAppQueueManager( | |||
| @@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| message_id=message.id | |||
| ) | |||
| # Init conversation variables | |||
| stmt = select(ConversationVariable).where( | |||
| ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id | |||
| ) | |||
| with Session(db.engine) as session: | |||
| conversation_variables = session.scalars(stmt).all() | |||
| if not conversation_variables: | |||
| # Create conversation variables if they don't exist. | |||
| conversation_variables = [ | |||
| ConversationVariable.from_variable( | |||
| app_id=conversation.app_id, conversation_id=conversation.id, variable=variable | |||
| ) | |||
| for variable in workflow.conversation_variables | |||
| ] | |||
| session.add_all(conversation_variables) | |||
| # Convert database entities to variables. | |||
| conversation_variables = [item.to_variable() for item in conversation_variables] | |||
| session.commit() | |||
| # Increment dialogue count. | |||
| conversation.dialogue_count += 1 | |||
| conversation_id = conversation.id | |||
| conversation_dialogue_count = conversation.dialogue_count | |||
| db.session.commit() | |||
| db.session.refresh(conversation) | |||
| inputs = application_generate_entity.inputs | |||
| query = application_generate_entity.query | |||
| files = application_generate_entity.files | |||
| user_id = None | |||
| if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = application_generate_entity.user_id | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariable.QUERY: query, | |||
| SystemVariable.FILES: files, | |||
| SystemVariable.CONVERSATION_ID: conversation_id, | |||
| SystemVariable.USER_ID: user_id, | |||
| SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, | |||
| } | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| ) | |||
| contexts.workflow_variable_pool.set(variable_pool) | |||
| # new thread | |||
| worker_thread = threading.Thread(target=self._generate_worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'application_generate_entity': application_generate_entity, | |||
| 'queue_manager': queue_manager, | |||
| 'conversation_id': conversation.id, | |||
| 'message_id': message.id, | |||
| 'user': user, | |||
| 'context': contextvars.copy_context() | |||
| 'context': contextvars.copy_context(), | |||
| }) | |||
| worker_thread.start() | |||
| @@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| message=message, | |||
| user=user, | |||
| stream=stream | |||
| stream=stream, | |||
| ) | |||
| return AdvancedChatAppGenerateResponseConverter.convert( | |||
| @@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| def _generate_worker(self, flask_app: Flask, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| conversation_id: str, | |||
| message_id: str, | |||
| user: Account, | |||
| context: contextvars.Context) -> None: | |||
| """ | |||
| Generate worker in a new thread. | |||
| @@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| user_id=application_generate_entity.user_id | |||
| ) | |||
| else: | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| # get message | |||
| message = self._get_message(message_id) | |||
| # chatbot app | |||
| @@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| runner.run( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message | |||
| ) | |||
| except GenerateTaskStoppedException: | |||
| @@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| finally: | |||
| db.session.close() | |||
| def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| workflow: Workflow, | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| user: Union[Account, EndUser], | |||
| stream: bool = False) \ | |||
| -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: | |||
| def _handle_advanced_chat_response( | |||
| self, | |||
| *, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| workflow: Workflow, | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| user: Union[Account, EndUser], | |||
| stream: bool = False, | |||
| ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: | |||
| """ | |||
| Handle response. | |||
| :param application_generate_entity: application generate entity | |||
| @@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| conversation=conversation, | |||
| message=message, | |||
| user=user, | |||
| stream=stream | |||
| stream=stream, | |||
| ) | |||
| try: | |||
| @@ -4,9 +4,6 @@ import time | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig | |||
| from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| @@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import ( | |||
| from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent | |||
| from core.moderation.base import ModerationException | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.workflow_engine_manager import WorkflowEngineManager | |||
| from extensions.ext_database import db | |||
| from models.model import App, Conversation, EndUser, Message | |||
| from models.workflow import ConversationVariable, Workflow | |||
| from models import App, Message, Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner): | |||
| self, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| ) -> None: | |||
| """ | |||
| @@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner): | |||
| inputs = application_generate_entity.inputs | |||
| query = application_generate_entity.query | |||
| files = application_generate_entity.files | |||
| user_id = None | |||
| if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = application_generate_entity.user_id | |||
| # moderation | |||
| if self.handle_input_moderation( | |||
| @@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner): | |||
| if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| # Init conversation variables | |||
| stmt = select(ConversationVariable).where( | |||
| ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id | |||
| ) | |||
| with Session(db.engine) as session: | |||
| conversation_variables = session.scalars(stmt).all() | |||
| if not conversation_variables: | |||
| conversation_variables = [ | |||
| ConversationVariable.from_variable( | |||
| app_id=conversation.app_id, conversation_id=conversation.id, variable=variable | |||
| ) | |||
| for variable in workflow.conversation_variables | |||
| ] | |||
| session.add_all(conversation_variables) | |||
| session.commit() | |||
| # Convert database entities to variables | |||
| conversation_variables = [item.to_variable() for item in conversation_variables] | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariable.QUERY: query, | |||
| SystemVariable.FILES: files, | |||
| SystemVariable.CONVERSATION_ID: conversation.id, | |||
| SystemVariable.USER_ID: user_id, | |||
| } | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| ) | |||
| # RUN WORKFLOW | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| workflow_engine_manager.run_workflow( | |||
| @@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner): | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| callbacks=workflow_callbacks, | |||
| call_depth=application_generate_entity.call_depth, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| def single_iteration_run( | |||
| @@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner): | |||
| """ | |||
| Single iteration run | |||
| """ | |||
| app_record: App = db.session.query(App).filter(App.id == app_id).first() | |||
| app_record = db.session.query(App).filter(App.id == app_id).first() | |||
| if not app_record: | |||
| raise ValueError('App not found') | |||
| @@ -4,6 +4,7 @@ import time | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional, Union, cast | |||
| import contexts | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| @@ -47,7 +48,8 @@ from core.file.file_obj import FileVar | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.node_entities import NodeType, SystemVariable | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.answer.answer_node import AnswerNode | |||
| from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk | |||
| from events.message_event import message_was_created | |||
| @@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| _application_generate_entity: AdvancedChatAppGenerateEntity | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| # Deprecated | |||
| _workflow_system_variables: dict[SystemVariable, Any] | |||
| _iteration_nested_relations: dict[str, list[str]] | |||
| @@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| conversation: Conversation, | |||
| message: Message, | |||
| user: Union[Account, EndUser], | |||
| stream: bool | |||
| stream: bool, | |||
| ) -> None: | |||
| """ | |||
| Initialize AdvancedChatAppGenerateTaskPipeline. | |||
| @@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| self._workflow = workflow | |||
| self._conversation = conversation | |||
| self._message = message | |||
| # Deprecated | |||
| self._workflow_system_variables = { | |||
| SystemVariable.QUERY: message.query, | |||
| SystemVariable.FILES: application_generate_entity.files, | |||
| SystemVariable.CONVERSATION_ID: conversation.id, | |||
| SystemVariable.USER_ID: user_id | |||
| SystemVariable.USER_ID: user_id, | |||
| } | |||
| self._task_state = AdvancedChatTaskState( | |||
| @@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if route_chunk_node_id == 'sys': | |||
| # system variable | |||
| value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1])) | |||
| value = contexts.workflow_variable_pool.get().get(value_selector) | |||
| if value: | |||
| value = value.text | |||
| elif route_chunk_node_id in self._iteration_nested_relations: | |||
| # it's a iteration variable | |||
| if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: | |||
| @@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| return introduction | |||
| def _get_conversation(self, conversation_id: str) -> Conversation: | |||
| def _get_conversation(self, conversation_id: str): | |||
| """ | |||
| Get conversation by conversation id | |||
| :param conversation_id: conversation id | |||
| @@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| raise ConversationNotExistsError() | |||
| return conversation | |||
| def _get_message(self, message_id: str) -> Message: | |||
| @@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import ( | |||
| WorkflowAppGenerateEntity, | |||
| ) | |||
| from core.workflow.callbacks.base_workflow_callback import WorkflowCallback | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.workflow_engine_manager import WorkflowEngineManager | |||
| from extensions.ext_database import db | |||
| @@ -42,7 +42,8 @@ from core.app.entities.task_entities import ( | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.node_entities import NodeType, SystemVariable | |||
| from core.workflow.entities.node_entities import NodeType | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.end.end_node import EndNode | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| @@ -519,7 +520,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| """ | |||
| nodes = graph.get('nodes') | |||
| iteration_ids = [node.get('id') for node in nodes | |||
| iteration_ids = [node.get('id') for node in nodes | |||
| if node.get('data', {}).get('type') in [ | |||
| NodeType.ITERATION.value, | |||
| NodeType.LOOP.value, | |||
| @@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id | |||
| ] for iteration_id in iteration_ids | |||
| } | |||
| @@ -2,7 +2,6 @@ from .segment_group import SegmentGroup | |||
| from .segments import ( | |||
| ArrayAnySegment, | |||
| ArraySegment, | |||
| FileSegment, | |||
| FloatSegment, | |||
| IntegerSegment, | |||
| NoneSegment, | |||
| @@ -13,11 +12,9 @@ from .segments import ( | |||
| from .types import SegmentType | |||
| from .variables import ( | |||
| ArrayAnyVariable, | |||
| ArrayFileVariable, | |||
| ArrayNumberVariable, | |||
| ArrayObjectVariable, | |||
| ArrayStringVariable, | |||
| FileVariable, | |||
| FloatVariable, | |||
| IntegerVariable, | |||
| NoneVariable, | |||
| @@ -32,7 +29,6 @@ __all__ = [ | |||
| 'FloatVariable', | |||
| 'ObjectVariable', | |||
| 'SecretVariable', | |||
| 'FileVariable', | |||
| 'StringVariable', | |||
| 'ArrayAnyVariable', | |||
| 'Variable', | |||
| @@ -45,11 +41,9 @@ __all__ = [ | |||
| 'FloatSegment', | |||
| 'ObjectSegment', | |||
| 'ArrayAnySegment', | |||
| 'FileSegment', | |||
| 'StringSegment', | |||
| 'ArrayStringVariable', | |||
| 'ArrayNumberVariable', | |||
| 'ArrayObjectVariable', | |||
| 'ArrayFileVariable', | |||
| 'ArraySegment', | |||
| ] | |||
| @@ -2,12 +2,10 @@ from collections.abc import Mapping | |||
| from typing import Any | |||
| from configs import dify_config | |||
| from core.file.file_obj import FileVar | |||
| from .exc import VariableError | |||
| from .segments import ( | |||
| ArrayAnySegment, | |||
| FileSegment, | |||
| FloatSegment, | |||
| IntegerSegment, | |||
| NoneSegment, | |||
| @@ -17,11 +15,9 @@ from .segments import ( | |||
| ) | |||
| from .types import SegmentType | |||
| from .variables import ( | |||
| ArrayFileVariable, | |||
| ArrayNumberVariable, | |||
| ArrayObjectVariable, | |||
| ArrayStringVariable, | |||
| FileVariable, | |||
| FloatVariable, | |||
| IntegerVariable, | |||
| ObjectVariable, | |||
| @@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: | |||
| result = FloatVariable.model_validate(mapping) | |||
| case SegmentType.NUMBER if not isinstance(value, float | int): | |||
| raise VariableError(f'invalid number value {value}') | |||
| case SegmentType.FILE: | |||
| result = FileVariable.model_validate(mapping) | |||
| case SegmentType.OBJECT if isinstance(value, dict): | |||
| result = ObjectVariable.model_validate(mapping) | |||
| case SegmentType.ARRAY_STRING if isinstance(value, list): | |||
| @@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: | |||
| result = ArrayNumberVariable.model_validate(mapping) | |||
| case SegmentType.ARRAY_OBJECT if isinstance(value, list): | |||
| result = ArrayObjectVariable.model_validate(mapping) | |||
| case SegmentType.ARRAY_FILE if isinstance(value, list): | |||
| mapping = dict(mapping) | |||
| mapping['value'] = [{'value': v} for v in value] | |||
| result = ArrayFileVariable.model_validate(mapping) | |||
| case _: | |||
| raise VariableError(f'not supported value type {value_type}') | |||
| if result.size > dify_config.MAX_VARIABLE_SIZE: | |||
| @@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment: | |||
| return ObjectSegment(value=value) | |||
| if isinstance(value, list): | |||
| return ArrayAnySegment(value=value) | |||
| if isinstance(value, FileVar): | |||
| return FileSegment(value=value) | |||
| raise ValueError(f'not supported value {value}') | |||
| @@ -5,8 +5,6 @@ from typing import Any | |||
| from pydantic import BaseModel, ConfigDict, field_validator | |||
| from core.file.file_obj import FileVar | |||
| from .types import SegmentType | |||
| @@ -78,14 +76,7 @@ class IntegerSegment(Segment): | |||
| value: int | |||
| class FileSegment(Segment): | |||
| value_type: SegmentType = SegmentType.FILE | |||
| # TODO: embed FileVar in this model. | |||
| value: FileVar | |||
| @property | |||
| def markdown(self) -> str: | |||
| return self.value.to_markdown() | |||
| class ObjectSegment(Segment): | |||
| @@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_OBJECT | |||
| value: Sequence[Mapping[str, Any]] | |||
| class ArrayFileSegment(ArraySegment): | |||
| value_type: SegmentType = SegmentType.ARRAY_FILE | |||
| value: Sequence[FileSegment] | |||
| @@ -10,8 +10,6 @@ class SegmentType(str, Enum): | |||
| ARRAY_STRING = 'array[string]' | |||
| ARRAY_NUMBER = 'array[number]' | |||
| ARRAY_OBJECT = 'array[object]' | |||
| ARRAY_FILE = 'array[file]' | |||
| OBJECT = 'object' | |||
| FILE = 'file' | |||
| GROUP = 'group' | |||
| @@ -4,11 +4,9 @@ from core.helper import encrypter | |||
| from .segments import ( | |||
| ArrayAnySegment, | |||
| ArrayFileSegment, | |||
| ArrayNumberSegment, | |||
| ArrayObjectSegment, | |||
| ArrayStringSegment, | |||
| FileSegment, | |||
| FloatSegment, | |||
| IntegerSegment, | |||
| NoneSegment, | |||
| @@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable): | |||
| pass | |||
| class FileVariable(FileSegment, Variable): | |||
| pass | |||
| class ObjectVariable(ObjectSegment, Variable): | |||
| pass | |||
| @@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): | |||
| pass | |||
| class ArrayFileVariable(ArrayFileSegment, Variable): | |||
| pass | |||
| class SecretVariable(StringVariable): | |||
| value_type: SegmentType = SegmentType.SECRET | |||
| @@ -2,7 +2,7 @@ from typing import Any, Union | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity | |||
| from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.enums import SystemVariable | |||
| from models.account import Account | |||
| from models.model import EndUser | |||
| from models.workflow import Workflow | |||
| @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] | |||
| _workflow_system_variables: dict[SystemVariable, Any] | |||
| _workflow_system_variables: dict[SystemVariable, Any] | |||
| @@ -4,13 +4,14 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from models import WorkflowNodeExecutionStatus | |||
| class NodeType(Enum): | |||
| """ | |||
| Node Types. | |||
| """ | |||
| START = 'start' | |||
| END = 'end' | |||
| ANSWER = 'answer' | |||
| @@ -44,33 +45,11 @@ class NodeType(Enum): | |||
| raise ValueError(f'invalid node type value {value}') | |||
| class SystemVariable(Enum): | |||
| """ | |||
| System Variables. | |||
| """ | |||
| QUERY = 'query' | |||
| FILES = 'files' | |||
| CONVERSATION_ID = 'conversation_id' | |||
| USER_ID = 'user_id' | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'SystemVariable': | |||
| """ | |||
| Get value of given system variable. | |||
| :param value: system variable value | |||
| :return: system variable | |||
| """ | |||
| for system_variable in cls: | |||
| if system_variable.value == value: | |||
| return system_variable | |||
| raise ValueError(f'invalid system variable value {value}') | |||
| class NodeRunMetadataKey(Enum): | |||
| """ | |||
| Node Run Metadata Key. | |||
| """ | |||
| TOTAL_TOKENS = 'total_tokens' | |||
| TOTAL_PRICE = 'total_price' | |||
| CURRENCY = 'currency' | |||
| @@ -83,6 +62,7 @@ class NodeRunResult(BaseModel): | |||
| """ | |||
| Node Run Result. | |||
| """ | |||
| status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING | |||
| inputs: Optional[Mapping[str, Any]] = None # node inputs | |||
| @@ -6,7 +6,7 @@ from typing_extensions import deprecated | |||
| from core.app.segments import Segment, Variable, factory | |||
| from core.file.file_obj import FileVar | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.enums import SystemVariable | |||
| VariableValue = Union[str, int, float, dict, list, FileVar] | |||
| @@ -0,0 +1,25 @@ | |||
| from enum import Enum | |||
| class SystemVariable(str, Enum): | |||
| """ | |||
| System Variables. | |||
| """ | |||
| QUERY = 'query' | |||
| FILES = 'files' | |||
| CONVERSATION_ID = 'conversation_id' | |||
| USER_ID = 'user_id' | |||
| DIALOGUE_COUNT = 'dialogue_count' | |||
| @classmethod | |||
| def value_of(cls, value: str): | |||
| """ | |||
| Get value of given system variable. | |||
| :param value: system variable value | |||
| :return: system variable | |||
| """ | |||
| for system_variable in cls: | |||
| if system_variable.value == value: | |||
| return system_variable | |||
| raise ValueError(f'invalid system variable value {value}') | |||
| @@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.llm.entities import ( | |||
| LLMNodeChatModelMessage, | |||
| @@ -201,8 +202,8 @@ class LLMNode(BaseNode): | |||
| usage = LLMUsage.empty_usage() | |||
| return full_text, usage | |||
| def _transform_chat_messages(self, | |||
| def _transform_chat_messages(self, | |||
| messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate | |||
| ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: | |||
| """ | |||
| @@ -249,13 +250,13 @@ class LLMNode(BaseNode): | |||
| # check if it's a context structure | |||
| if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: | |||
| return d['content'] | |||
| # else, parse the dict | |||
| try: | |||
| return json.dumps(d, ensure_ascii=False) | |||
| except Exception: | |||
| return str(d) | |||
| if isinstance(value, str): | |||
| value = value | |||
| elif isinstance(value, list): | |||
| @@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence | |||
| from os import path | |||
| from typing import Any, cast | |||
| from core.app.segments import parser | |||
| from core.app.segments import ArrayAnyVariable, parser | |||
| from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler | |||
| from core.file.file_obj import FileTransferMethod, FileType, FileVar | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter | |||
| from core.tools.tool_engine import ToolEngine | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.tools.utils.message_transformer import ToolFileMessageTransformer | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.base_node import BaseNode | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from models import WorkflowNodeExecutionStatus | |||
| class ToolNode(BaseNode): | |||
| @@ -140,9 +141,9 @@ class ToolNode(BaseNode): | |||
| return result | |||
| def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: | |||
| # FIXME: ensure this is a ArrayVariable contains FileVariable. | |||
| variable = variable_pool.get(['sys', SystemVariable.FILES.value]) | |||
| return [file_var.value for file_var in variable.value] if variable else [] | |||
| assert isinstance(variable, ArrayAnyVariable) | |||
| return list(variable.value) if variable else [] | |||
| def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): | |||
| """ | |||
| @@ -3,6 +3,7 @@ import time | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| import contexts | |||
| from configs import dify_config | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| @@ -97,7 +98,7 @@ class WorkflowEngineManager: | |||
| invoke_from: InvokeFrom, | |||
| callbacks: Sequence[WorkflowCallback], | |||
| call_depth: int = 0, | |||
| variable_pool: VariablePool, | |||
| variable_pool: VariablePool | None = None, | |||
| ) -> None: | |||
| """ | |||
| :param workflow: Workflow instance | |||
| @@ -128,6 +129,8 @@ class WorkflowEngineManager: | |||
| raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) | |||
| # init workflow run state | |||
| if not variable_pool: | |||
| variable_pool = contexts.workflow_variable_pool.get() | |||
| workflow_run_state = WorkflowRunState( | |||
| workflow=workflow, | |||
| start_at=time.perf_counter(), | |||
| @@ -0,0 +1,33 @@ | |||
| """add conversations.dialogue_count | |||
| Revision ID: 8782057ff0dc | |||
| Revises: 63a83fcf12ba | |||
| Create Date: 2024-08-14 13:54:25.161324 | |||
| """ | |||
| import sqlalchemy as sa | |||
| from alembic import op | |||
| import models as models | |||
| # revision identifiers, used by Alembic. | |||
| revision = '8782057ff0dc' | |||
| down_revision = '63a83fcf12ba' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('conversations', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('conversations', schema=None) as batch_op: | |||
| batch_op.drop_column('dialogue_count') | |||
| # ### end Alembic commands ### | |||
| @@ -1,10 +1,10 @@ | |||
| from enum import Enum | |||
| from .model import AppMode | |||
| from .model import App, AppMode, Message | |||
| from .types import StringUUID | |||
| from .workflow import ConversationVariable, WorkflowNodeExecutionStatus | |||
| from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus | |||
| __all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus'] | |||
| __all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] | |||
| class CreatedByRole(Enum): | |||
| @@ -7,6 +7,7 @@ from typing import Optional | |||
| from flask import request | |||
| from flask_login import UserMixin | |||
| from sqlalchemy import Float, func, text | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| from configs import dify_config | |||
| from core.file.tool_file_parser import ToolFileParser | |||
| @@ -512,12 +513,12 @@ class Conversation(db.Model): | |||
| from_account_id = db.Column(StringUUID) | |||
| read_at = db.Column(db.DateTime) | |||
| read_account_id = db.Column(StringUUID) | |||
| dialogue_count: Mapped[int] = mapped_column(default=0) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") | |||
| message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', | |||
| passive_deletes="all") | |||
| message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") | |||
| is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) | |||
| @@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers import ModelProviderFactory | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.nodes.llm.llm_node import LLMNode | |||
| from extensions.ext_database import db | |||
| @@ -236,4 +236,4 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert 'sunny' in json.dumps(result.process_data) | |||
| assert 'what\'s the weather today?' in json.dumps(result.process_data) | |||
| assert 'what\'s the weather today?' in json.dumps(result.process_data) | |||
| @@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode | |||
| from extensions.ext_database import db | |||
| @@ -363,7 +363,7 @@ def test_extract_json_response(): | |||
| { | |||
| "location": "kawaii" | |||
| } | |||
| hello world. | |||
| hello world. | |||
| """) | |||
| assert result['location'] == 'kawaii' | |||
| @@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): | |||
| assert latest_role != prompt.get('role') | |||
| if prompt.get('role') in ['user', 'assistant']: | |||
| latest_role = prompt.get('role') | |||
| latest_role = prompt.get('role') | |||
| @@ -3,12 +3,9 @@ from uuid import uuid4 | |||
| import pytest | |||
| from core.app.segments import ( | |||
| ArrayFileVariable, | |||
| ArrayNumberVariable, | |||
| ArrayObjectVariable, | |||
| ArrayStringVariable, | |||
| FileSegment, | |||
| FileVariable, | |||
| FloatVariable, | |||
| IntegerVariable, | |||
| ObjectSegment, | |||
| @@ -149,83 +146,6 @@ def test_array_object_variable(): | |||
| assert isinstance(variable.value[1]['key2'], int) | |||
| def test_file_variable(): | |||
| mapping = { | |||
| 'id': str(uuid4()), | |||
| 'value_type': 'file', | |||
| 'name': 'test_file', | |||
| 'description': 'Description of the variable.', | |||
| 'value': { | |||
| 'id': str(uuid4()), | |||
| 'tenant_id': 'tenant_id', | |||
| 'type': 'image', | |||
| 'transfer_method': 'local_file', | |||
| 'url': 'url', | |||
| 'related_id': 'related_id', | |||
| 'extra_config': { | |||
| 'image_config': { | |||
| 'width': 100, | |||
| 'height': 100, | |||
| }, | |||
| }, | |||
| 'filename': 'filename', | |||
| 'extension': 'extension', | |||
| 'mime_type': 'mime_type', | |||
| }, | |||
| } | |||
| variable = factory.build_variable_from_mapping(mapping) | |||
| assert isinstance(variable, FileVariable) | |||
| def test_array_file_variable(): | |||
| mapping = { | |||
| 'id': str(uuid4()), | |||
| 'value_type': 'array[file]', | |||
| 'name': 'test_array_file', | |||
| 'description': 'Description of the variable.', | |||
| 'value': [ | |||
| { | |||
| 'id': str(uuid4()), | |||
| 'tenant_id': 'tenant_id', | |||
| 'type': 'image', | |||
| 'transfer_method': 'local_file', | |||
| 'url': 'url', | |||
| 'related_id': 'related_id', | |||
| 'extra_config': { | |||
| 'image_config': { | |||
| 'width': 100, | |||
| 'height': 100, | |||
| }, | |||
| }, | |||
| 'filename': 'filename', | |||
| 'extension': 'extension', | |||
| 'mime_type': 'mime_type', | |||
| }, | |||
| { | |||
| 'id': str(uuid4()), | |||
| 'tenant_id': 'tenant_id', | |||
| 'type': 'image', | |||
| 'transfer_method': 'local_file', | |||
| 'url': 'url', | |||
| 'related_id': 'related_id', | |||
| 'extra_config': { | |||
| 'image_config': { | |||
| 'width': 100, | |||
| 'height': 100, | |||
| }, | |||
| }, | |||
| 'filename': 'filename', | |||
| 'extension': 'extension', | |||
| 'mime_type': 'mime_type', | |||
| }, | |||
| ], | |||
| } | |||
| variable = factory.build_variable_from_mapping(mapping) | |||
| assert isinstance(variable, ArrayFileVariable) | |||
| assert isinstance(variable.value[0], FileSegment) | |||
| assert isinstance(variable.value[1], FileSegment) | |||
| def test_variable_cannot_large_than_5_kb(): | |||
| with pytest.raises(VariableError): | |||
| factory.build_variable_from_mapping( | |||
| @@ -1,7 +1,7 @@ | |||
| from core.app.segments import SecretVariable, StringSegment, parser | |||
| from core.helper import encrypter | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| def test_segment_group_to_text(): | |||
| @@ -1,8 +1,8 @@ | |||
| from unittest.mock import MagicMock | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.answer.answer_node import AnswerNode | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from extensions.ext_database import db | |||
| @@ -1,8 +1,8 @@ | |||
| from unittest.mock import MagicMock | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.nodes.if_else.if_else_node import IfElseNode | |||
| from extensions.ext_database import db | |||
| @@ -3,8 +3,8 @@ from uuid import uuid4 | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.segments import ArrayStringVariable, StringVariable | |||
| from core.workflow.entities.node_entities import SystemVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariable | |||
| from core.workflow.nodes.base_node import UserFrom | |||
| from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode | |||