Browse Source

refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445)

tags/0.7.2
-LAN- 1 year ago
parent
commit
4f5f27cf2b
No account linked to committer's email address

+ 16
- 16
api/core/app/apps/advanced_chat/app_generator.py View File

@@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
@@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
):
"""
Generate App response.

@@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)

# parse files
files = args['files'] if args.get('files') else []
@@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
"""
Generate App response.

@@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)

# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
@@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
stream: bool = True):
is_first_conversation = False
if not conversation:
is_first_conversation = True
@@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation_id,
SystemVariable.USER_ID: user_id,
SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count,
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
@@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

+ 6
- 6
api/core/app/apps/advanced_chat/generate_task_pipeline.py View File

@@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
@@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]

def __init__(
@@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id,
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id,
}

self._task_state = AdvancedChatTaskState(

+ 3
- 3
api/core/app/apps/workflow/app_runner.py View File

@@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@@ -67,8 +67,8 @@ class WorkflowAppRunner:

# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,

+ 4
- 4
api/core/app/apps/workflow/generate_task_pipeline.py View File

@@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
@@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]

def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
@@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

self._workflow = workflow
self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id
}

self._task_state = WorkflowTaskState(

+ 2
- 2
api/core/app/task_pipeline/workflow_cycle_state_manager.py View File

@@ -2,7 +2,7 @@ from typing import Any, Union

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
@@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariableKey, Any]

+ 9
- 9
api/core/workflow/entities/variable_pool.py View File

@@ -6,20 +6,20 @@ from typing_extensions import deprecated

from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey

VariableValue = Union[str, int, float, dict, list, FileVar]


SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"


class VariablePool:
def __init__(
self,
system_variables: Mapping[SystemVariable, Any],
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
@@ -68,7 +68,7 @@ class VariablePool:
None
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")

if value is None:
return
@@ -95,13 +95,13 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)

return value

@deprecated('This method is deprecated, use `get` instead.')
@deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Retrieves the value from the variable pool based on the given selector.
@@ -116,7 +116,7 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None

+ 6
- 18
api/core/workflow/enums.py View File

@@ -1,25 +1,13 @@
from enum import Enum


class SystemVariable(str, Enum):
class SystemVariableKey(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'

@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.

:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
QUERY = "query"
FILES = "files"
CONVERSATION_ID = "conversation_id"
USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"

+ 6
- 6
api/core/workflow/nodes/llm/llm_node.py View File

@@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
@@ -94,7 +94,7 @@ class LLMNode(BaseNode):
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
@@ -335,7 +335,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled:
return []

files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
if not files:
return []

@@ -500,7 +500,7 @@ class LLMNode(BaseNode):
return None

# get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None:
return None

@@ -672,10 +672,10 @@ class LLMNode(BaseNode):
variable_mapping['#context#'] = node_data.context.variable_selector

if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]

if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]

if node_data.prompt_config:
enable_jinja = False

+ 7
- 7
api/core/workflow/nodes/start/start_node.py View File

@@ -1,7 +1,7 @@

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@@ -17,16 +17,16 @@ class StartNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
# Get cleaned inputs
cleaned_inputs = dict(variable_pool.user_inputs)
node_inputs = dict(variable_pool.user_inputs)
system_inputs = variable_pool.system_variables

for var in variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs,
outputs=cleaned_inputs
inputs=node_inputs,
outputs=node_inputs
)

@classmethod

+ 2
- 2
api/core/workflow/nodes/tool/tool_node.py View File

@@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -141,7 +141,7 @@ class ToolNode(BaseNode):
return result

def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []


+ 9
- 9
api/tests/integration_tests/workflow/nodes/test_llm.py View File

@@ -11,7 +11,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
@@ -66,10 +66,10 @@ def test_execute_llm(setup_openai_mock):

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny')

@@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather today?',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather today?',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['abc', 'output'], 'sunny')


+ 21
- 21
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py View File

@@ -13,7 +13,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db
@@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])

result = node.run(pool)
@@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock):

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])

result = node.run(pool)
@@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock):

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])

result = node.run(pool)
@@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock):

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])

result = node.run(pool)
@@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.QUERY: 'what\'s the weather in SF',
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: 'abababa',
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])

result = node.run(pool)

+ 3
- 3
api/tests/unit_tests/core/app/segments/test_segment.py View File

@@ -1,13 +1,13 @@
from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey


def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariable('user_id'): 'fake-user-id',
SystemVariableKey('user_id'): 'fake-user-id',
},
user_inputs={},
environment_variables=[
@@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariable('user_id'): 'fake-user-id',
SystemVariableKey('user_id'): 'fake-user-id',
},
user_inputs={},
environment_variables=[],

+ 3
- 3
api/tests/unit_tests/core/workflow/nodes/test_answer.py View File

@@ -2,7 +2,7 @@ from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db
@@ -29,8 +29,8 @@ def test_execute_answer():

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'weather'], 'sunny')
pool.add(['llm', 'text'], 'You are a helpful AI.')

+ 5
- 5
api/tests/unit_tests/core/workflow/nodes/test_if_else.py View File

@@ -2,7 +2,7 @@ from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db
@@ -119,8 +119,8 @@ def test_execute_if_else_result_true():

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
@@ -182,8 +182,8 @@ def test_execute_if_else_result_false():

# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])

+ 4
- 4
api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py View File

@@ -4,7 +4,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode

@@ -42,7 +42,7 @@ def test_overwrite_string_variable():
)

variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -93,7 +93,7 @@ def test_append_variable_to_array():
)

variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -137,7 +137,7 @@ def test_clear_array():
)

variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],

Loading…
Cancel
Save