Ver código fonte

feat(api/workflow): Add `Conversation.dialogue_count` (#7275)

tags/0.7.1
-LAN- 1 ano atrás
pai
commit
32dc963556
Nenhuma conta vinculada ao e-mail do autor do commit
29 arquivos alterados com 205 adições e 259 exclusões
  1. 5
    1
      api/contexts/__init__.py
  2. 84
    24
      api/core/app/apps/advanced_chat/app_generator.py
  3. 2
    51
      api/core/app/apps/advanced_chat/app_runner.py
  4. 10
    4
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  5. 4
    1
      api/core/app/apps/message_based_app_generator.py
  6. 1
    1
      api/core/app/apps/workflow/app_runner.py
  7. 3
    3
      api/core/app/apps/workflow/generate_task_pipeline.py
  8. 0
    6
      api/core/app/segments/__init__.py
  9. 0
    12
      api/core/app/segments/factory.py
  10. 0
    13
      api/core/app/segments/segments.py
  11. 0
    2
      api/core/app/segments/types.py
  12. 0
    9
      api/core/app/segments/variables.py
  13. 2
    2
      api/core/app/task_pipeline/workflow_cycle_state_manager.py
  14. 4
    24
      api/core/workflow/entities/node_entities.py
  15. 1
    1
      api/core/workflow/entities/variable_pool.py
  16. 25
    0
      api/core/workflow/enums.py
  17. 6
    5
      api/core/workflow/nodes/llm/llm_node.py
  18. 6
    5
      api/core/workflow/nodes/tool/tool_node.py
  19. 4
    1
      api/core/workflow/workflow_engine_manager.py
  20. 33
    0
      api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py
  21. 3
    3
      api/models/__init__.py
  22. 3
    2
      api/models/model.py
  23. 2
    2
      api/tests/integration_tests/workflow/nodes/test_llm.py
  24. 3
    3
      api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
  25. 0
    80
      api/tests/unit_tests/core/app/segments/test_factory.py
  26. 1
    1
      api/tests/unit_tests/core/app/segments/test_segment.py
  27. 1
    1
      api/tests/unit_tests/core/workflow/nodes/test_answer.py
  28. 1
    1
      api/tests/unit_tests/core/workflow/nodes/test_if_else.py
  29. 1
    1
      api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py

+ 5
- 1
api/contexts/__init__.py Ver arquivo

@@ -1,3 +1,7 @@
from contextvars import ContextVar

tenant_id: ContextVar[str] = ContextVar('tenant_id')
from core.workflow.entities.variable_pool import VariablePool

tenant_id: ContextVar[str] = ContextVar('tenant_id')

workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool')

+ 84
- 24
api/core/app/apps/advanced_chat/app_generator.py Ver arquivo

@@ -8,6 +8,8 @@ from typing import Union

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session

import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -18,15 +20,20 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
from models.workflow import ConversationVariable, Workflow

logger = logging.getLogger(__name__)

@@ -120,7 +127,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
stream=stream
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
@@ -140,10 +147,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
@@ -209,7 +216,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# db.session.refresh(conversation)

# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -221,15 +228,69 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)

# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]

session.commit()

# Increment dialogue count.
conversation.dialogue_count += 1

conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)

inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files

user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id

# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'user': user,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
})

worker_thread.start()
@@ -242,7 +303,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)

return AdvancedChatAppGenerateResponseConverter.convert(
@@ -253,9 +314,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
user: Account,
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
@@ -282,8 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user_id=application_generate_entity.user_id
)
else:
# get conversation and message
conversation = self._get_conversation(conversation_id)
# get message
message = self._get_message(message_id)

# chatbot app
@@ -291,7 +349,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except GenerateTaskStoppedException:
@@ -314,14 +371,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()

def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False) \
-> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
def _handle_advanced_chat_response(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@@ -341,7 +401,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)

try:

+ 2
- 51
api/core/app/apps/advanced_chat/app_runner.py Ver arquivo

@@ -4,9 +4,6 @@ import time
from collections.abc import Mapping
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -19,13 +16,10 @@ from core.app.entities.app_invoke_entities import (
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models import App, Message, Workflow

logger = logging.getLogger(__name__)

@@ -39,7 +33,6 @@ class AdvancedChatAppRunner(AppRunner):
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
@@ -63,15 +56,6 @@ class AdvancedChatAppRunner(AppRunner):

inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files

user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id

# moderation
if self.handle_input_moderation(
@@ -103,38 +87,6 @@ class AdvancedChatAppRunner(AppRunner):
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())

# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]

# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)

# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
@@ -146,7 +98,6 @@ class AdvancedChatAppRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
variable_pool=variable_pool,
)

def single_iteration_run(
@@ -155,7 +106,7 @@ class AdvancedChatAppRunner(AppRunner):
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')


+ 10
- 4
api/core/app/apps/advanced_chat/generate_task_pipeline.py Ver arquivo

@@ -4,6 +4,7 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast

import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -47,7 +48,8 @@ from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@@ -71,6 +73,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]

@@ -81,7 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -103,11 +106,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
SystemVariable.USER_ID: user_id,
}

self._task_state = AdvancedChatTaskState(
@@ -613,7 +617,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

if route_chunk_node_id == 'sys':
# system variable
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:

+ 4
- 1
api/core/app/apps/message_based_app_generator.py Ver arquivo

@@ -258,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):

return introduction

def _get_conversation(self, conversation_id: str) -> Conversation:
def _get_conversation(self, conversation_id: str):
"""
Get conversation by conversation id
:param conversation_id: conversation id
@@ -270,6 +270,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
.first()
)

if not conversation:
raise ConversationNotExistsError()

return conversation

def _get_message(self, message_id: str) -> Message:

+ 1
- 1
api/core/app/apps/workflow/app_runner.py Ver arquivo

@@ -11,8 +11,8 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db

+ 3
- 3
api/core/app/apps/workflow/generate_task_pipeline.py Ver arquivo

@@ -42,7 +42,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@@ -519,7 +520,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
nodes = graph.get('nodes')

iteration_ids = [node.get('id') for node in nodes
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
@@ -530,4 +531,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

+ 0
- 6
api/core/app/segments/__init__.py Ver arquivo

@@ -2,7 +2,6 @@ from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@@ -13,11 +12,9 @@ from .segments import (
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
@@ -32,7 +29,6 @@ __all__ = [
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'FileVariable',
'StringVariable',
'ArrayAnyVariable',
'Variable',
@@ -45,11 +41,9 @@ __all__ = [
'FloatSegment',
'ObjectSegment',
'ArrayAnySegment',
'FileSegment',
'StringSegment',
'ArrayStringVariable',
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArrayFileVariable',
'ArraySegment',
]

+ 0
- 12
api/core/app/segments/factory.py Ver arquivo

@@ -2,12 +2,10 @@ from collections.abc import Mapping
from typing import Any

from configs import dify_config
from core.file.file_obj import FileVar

from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@@ -17,11 +15,9 @@ from .segments import (
)
from .types import SegmentType
from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
@@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}')
case SegmentType.FILE:
result = FileVariable.model_validate(mapping)
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
@@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
mapping = dict(mapping)
mapping['value'] = [{'value': v} for v in value]
result = ArrayFileVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
@@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')

+ 0
- 13
api/core/app/segments/segments.py Ver arquivo

@@ -5,8 +5,6 @@ from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator

from core.file.file_obj import FileVar

from .types import SegmentType


@@ -78,14 +76,7 @@ class IntegerSegment(Segment):
value: int


class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
# TODO: embed FileVar in this model.
value: FileVar

@property
def markdown(self) -> str:
return self.value.to_markdown()


class ObjectSegment(Segment):
@@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]


class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[FileSegment]

+ 0
- 2
api/core/app/segments/types.py Ver arquivo

@@ -10,8 +10,6 @@ class SegmentType(str, Enum):
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
ARRAY_FILE = 'array[file]'
OBJECT = 'object'
FILE = 'file'

GROUP = 'group'

+ 0
- 9
api/core/app/segments/variables.py Ver arquivo

@@ -4,11 +4,9 @@ from core.helper import encrypter

from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
@@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
pass


class FileVariable(FileSegment, Variable):
pass


class ObjectVariable(ObjectSegment, Variable):
pass

@@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass


class ArrayFileVariable(ArrayFileSegment, Variable):
pass


class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET

+ 2
- 2
api/core/app/task_pipeline/workflow_cycle_state_manager.py Ver arquivo

@@ -2,7 +2,7 @@ from typing import Any, Union

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
@@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariable, Any]

+ 4
- 24
api/core/workflow/entities/node_entities.py Ver arquivo

@@ -4,13 +4,14 @@ from typing import Any, Optional

from pydantic import BaseModel

from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus


class NodeType(Enum):
"""
Node Types.
"""

START = 'start'
END = 'end'
ANSWER = 'answer'
@@ -44,33 +45,11 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}')


class SystemVariable(Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'

@classmethod
def value_of(cls, value: str) -> 'SystemVariable':
"""
Get value of given system variable.

:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')


class NodeRunMetadataKey(Enum):
"""
Node Run Metadata Key.
"""

TOTAL_TOKENS = 'total_tokens'
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
@@ -83,6 +62,7 @@ class NodeRunResult(BaseModel):
"""
Node Run Result.
"""

status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING

inputs: Optional[Mapping[str, Any]] = None # node inputs

+ 1
- 1
api/core/workflow/entities/variable_pool.py Ver arquivo

@@ -6,7 +6,7 @@ from typing_extensions import deprecated

from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable

VariableValue = Union[str, int, float, dict, list, FileVar]


+ 25
- 0
api/core/workflow/enums.py Ver arquivo

@@ -0,0 +1,25 @@
from enum import Enum


class SystemVariable(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'

@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.

:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')

+ 6
- 5
api/core/workflow/nodes/llm/llm_node.py Ver arquivo

@@ -23,8 +23,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
@@ -201,8 +202,8 @@ class LLMNode(BaseNode):
usage = LLMUsage.empty_usage()

return full_text, usage
def _transform_chat_messages(self,
def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
@@ -249,13 +250,13 @@ class LLMNode(BaseNode):
# check if it's a context structure
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
return d['content']
# else, parse the dict
try:
return json.dumps(d, ensure_ascii=False)
except Exception:
return str(d)
if isinstance(value, str):
value = value
elif isinstance(value, list):

+ 6
- 5
api/core/workflow/nodes/tool/tool_node.py Ver arquivo

@@ -2,19 +2,20 @@ from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast

from core.app.segments import parser
from core.app.segments import ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus


class ToolNode(BaseNode):
@@ -140,9 +141,9 @@ class ToolNode(BaseNode):
return result

def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
# FIXME: ensure this is a ArrayVariable contains FileVariable.
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
return [file_var.value for file_var in variable.value] if variable else []
assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []

def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
"""

+ 4
- 1
api/core/workflow/workflow_engine_manager.py Ver arquivo

@@ -3,6 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast

import contexts
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -97,7 +98,7 @@ class WorkflowEngineManager:
invoke_from: InvokeFrom,
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0,
variable_pool: VariablePool,
variable_pool: VariablePool | None = None,
) -> None:
"""
:param workflow: Workflow instance
@@ -128,6 +129,8 @@ class WorkflowEngineManager:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))

# init workflow run state
if not variable_pool:
variable_pool = contexts.workflow_variable_pool.get()
workflow_run_state = WorkflowRunState(
workflow=workflow,
start_at=time.perf_counter(),

+ 33
- 0
api/migrations/versions/2024_08_14_1354-8782057ff0dc_add_conversations_dialogue_count.py Ver arquivo

@@ -0,0 +1,33 @@
"""add conversations.dialogue_count

Revision ID: 8782057ff0dc
Revises: 63a83fcf12ba
Create Date: 2024-08-14 13:54:25.161324

"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '8782057ff0dc'
down_revision = '63a83fcf12ba'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('dialogue_count', sa.Integer(), server_default='0', nullable=False))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_column('dialogue_count')

# ### end Alembic commands ###

+ 3
- 3
api/models/__init__.py Ver arquivo

@@ -1,10 +1,10 @@
from enum import Enum

from .model import AppMode
from .model import App, AppMode, Message
from .types import StringUUID
from .workflow import ConversationVariable, WorkflowNodeExecutionStatus
from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus

__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus']
__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message']


class CreatedByRole(Enum):

+ 3
- 2
api/models/model.py Ver arquivo

@@ -7,6 +7,7 @@ from typing import Optional
from flask import request
from flask_login import UserMixin
from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column

from configs import dify_config
from core.file.tool_file_parser import ToolFileParser
@@ -512,12 +513,12 @@ class Conversation(db.Model):
from_account_id = db.Column(StringUUID)
read_at = db.Column(db.DateTime)
read_account_id = db.Column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select',
passive_deletes="all")
message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")

is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))


+ 2
- 2
api/tests/integration_tests/workflow/nodes/test_llm.py Ver arquivo

@@ -10,8 +10,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
@@ -236,4 +236,4 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):

assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert 'sunny' in json.dumps(result.process_data)
assert 'what\'s the weather today?' in json.dumps(result.process_data)
assert 'what\'s the weather today?' in json.dumps(result.process_data)

+ 3
- 3
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py Ver arquivo

@@ -12,8 +12,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db
@@ -363,7 +363,7 @@ def test_extract_json_response():
{
"location": "kawaii"
}
hello world.
hello world.
""")

assert result['location'] == 'kawaii'
@@ -445,4 +445,4 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
assert latest_role != prompt.get('role')

if prompt.get('role') in ['user', 'assistant']:
latest_role = prompt.get('role')
latest_role = prompt.get('role')

+ 0
- 80
api/tests/unit_tests/core/app/segments/test_factory.py Ver arquivo

@@ -3,12 +3,9 @@ from uuid import uuid4
import pytest

from core.app.segments import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileSegment,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
@@ -149,83 +146,6 @@ def test_array_object_variable():
assert isinstance(variable.value[1]['key2'], int)


def test_file_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'file',
'name': 'test_file',
'description': 'Description of the variable.',
'value': {
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, FileVariable)


def test_array_file_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'array[file]',
'name': 'test_array_file',
'description': 'Description of the variable.',
'value': [
{
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
{
'id': str(uuid4()),
'tenant_id': 'tenant_id',
'type': 'image',
'transfer_method': 'local_file',
'url': 'url',
'related_id': 'related_id',
'extra_config': {
'image_config': {
'width': 100,
'height': 100,
},
},
'filename': 'filename',
'extension': 'extension',
'mime_type': 'mime_type',
},
],
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayFileVariable)
assert isinstance(variable.value[0], FileSegment)
assert isinstance(variable.value[1], FileSegment)


def test_variable_cannot_large_than_5_kb():
with pytest.raises(VariableError):
factory.build_variable_from_mapping(

+ 1
- 1
api/tests/unit_tests/core/app/segments/test_segment.py Ver arquivo

@@ -1,7 +1,7 @@
from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable


def test_segment_group_to_text():

+ 1
- 1
api/tests/unit_tests/core/workflow/nodes/test_answer.py Ver arquivo

@@ -1,8 +1,8 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db

+ 1
- 1
api/tests/unit_tests/core/workflow/nodes/test_if_else.py Ver arquivo

@@ -1,8 +1,8 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db

+ 1
- 1
api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py Ver arquivo

@@ -3,8 +3,8 @@ from uuid import uuid4

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode


Carregando…
Cancelar
Salvar