ソースを参照

refactor(variables): replace deprecated 'get_any' with 'get' method (#9584)

tags/0.10.1
-LAN- 1年前
コミット
8f670f31b8
コミッターのメールアドレスに関連付けられたアカウントが存在しません

+ 0
- 21
api/core/workflow/entities/variable_pool.py ファイルの表示

from typing import Any, Union from typing import Any, Union


from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import deprecated


from core.file import File, FileAttribute, file_manager from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable from core.variables import Segment, SegmentGroup, Variable


return value return value


@deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Retrieves the value from the variable pool based on the given selector.

Args:
selector (Sequence[str]): The selector used to identify the variable.

Returns:
Any: The value associated with the given selector.

Raises:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None

def remove(self, selector: Sequence[str], /): def remove(self, selector: Sequence[str], /):
""" """
Remove variables from the variable pool based on the given selector. Remove variables from the variable pool based on the given selector.

+ 9
- 4
api/core/workflow/nodes/code/code_node.py ファイルの表示

# Get variables # Get variables
variables = {} variables = {}
for variable_selector in self.node_data.variables: for variable_selector in self.node_data.variables:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)

variables[variable] = value
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=f"Variable `{variable_selector.value_selector}` not found",
)
variables[variable_name] = variable.to_object()
# Run code # Run code
try: try:
result = CodeExecutor.execute_workflow_code_template( result = CodeExecutor.execute_workflow_code_template(

+ 25
- 10
api/core/workflow/nodes/iteration/iteration_node.py ファイルの表示



from configs import dify_config from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import IntegerSegment
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
BaseGraphEvent, BaseGraphEvent,


if NodeRunMetadataKey.ITERATION_ID not in metadata: if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
[self.node_id, "index"]
)
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerSegment):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Invalid index variable type: {type(index_variable)}",
)
)
return
metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
event.route_node_state.node_run_result.metadata = metadata event.route_node_state.node_run_result.metadata = metadata


yield event yield event
yield event yield event


# append to iteration output variable list # append to iteration output variable list
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
if current_iteration_output_variable is None:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Iteration output variable {self.node_data.output_selector} not found",
)
)
return
current_iteration_output = current_iteration_output_variable.to_object()
outputs.append(current_iteration_output) outputs.append(current_iteration_output)


# remove all nodes outputs from variable pool # remove all nodes outputs from variable pool
variable_pool.remove([node_id]) variable_pool.remove([node_id])


# move to next iteration # move to next iteration
current_index = variable_pool.get([self.node_id, "index"])
if current_index is None:
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"iteration {self.node_id} current index not found") raise ValueError(f"iteration {self.node_id} current index not found")


next_index = int(current_index.to_object()) + 1
next_index = current_index_variable.value + 1
variable_pool.add([self.node_id, "index"], next_index) variable_pool.add([self.node_id, "index"], next_index)


if next_index < len(iterator_list_value): if next_index < len(iterator_list_value):
iteration_node_type=self.node_type, iteration_node_type=self.node_type,
iteration_node_data=self.node_data, iteration_node_data=self.node_data,
index=next_index, index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output)
if current_iteration_output
else None,
pre_iteration_output=jsonable_encoder(current_iteration_output),
) )


yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(

+ 9
- 2
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py ファイルの表示

from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType


def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# extract variables # extract variables
variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector)
query = variable
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not string type.",
)
query = variable.value
variables = {"query": query} variables = {"query": query}
if not query: if not query:
return NodeRunResult( return NodeRunResult(

+ 34
- 31
api/core/workflow/nodes/llm/node.py ファイルの表示

from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
return variables return variables


for variable_selector in node_data.prompt_config.jinja2_variables or []: for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")


def parse_dict(d: dict) -> str:
def parse_dict(input_dict: Mapping[str, Any]) -> str:
""" """
Parse dict into string Parse dict into string
""" """
# check if it's a context structure # check if it's a context structure
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
return d["content"]
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return input_dict["content"]


# else, parse the dict # else, parse the dict
try: try:
return json.dumps(d, ensure_ascii=False)
return json.dumps(input_dict, ensure_ascii=False)
except Exception: except Exception:
return str(d)
return str(input_dict)


if isinstance(value, str):
value = value
elif isinstance(value, list):
if isinstance(variable, ArraySegment):
result = "" result = ""
for item in value:
for item in variable.value:
if isinstance(item, dict): if isinstance(item, dict):
result += parse_dict(item) result += parse_dict(item)
elif isinstance(item, str):
result += item
elif isinstance(item, int | float):
result += str(item)
else: else:
result += str(item) result += str(item)
result += "\n" result += "\n"
value = result.strip() value = result.strip()
elif isinstance(value, dict):
value = parse_dict(value)
elif isinstance(value, int | float):
value = str(value)
elif isinstance(variable, ObjectSegment):
value = parse_dict(variable.value)
else: else:
value = str(value)
value = variable.text


variables[variable] = value
variables[variable_name] = value


return variables return variables


def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
inputs = {} inputs = {}
prompt_template = node_data.prompt_template prompt_template = node_data.prompt_template


if not node_data.context.variable_selector: if not node_data.context.variable_selector:
return return


context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
elif isinstance(context_value, list):
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
if context_value_variable:
if isinstance(context_value_variable, StringSegment):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment):
context_str = "" context_str = ""
original_retriever_resource = [] original_retriever_resource = []
for item in context_value:
for item in context_value_variable.value:
if isinstance(item, str): if isinstance(item, str):
context_str += item + "\n" context_str += item + "\n"
else: else:
return None return None


# get conversation id # get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get_any(
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value] ["sys", SystemVariableKey.CONVERSATION_ID.value]
) )
if conversation_id is None:
if not isinstance(conversation_id_variable, StringSegment):
return None return None
conversation_id = conversation_id_variable.value


# get conversation # get conversation
conversation = ( conversation = (

+ 7
- 2
api/core/workflow/nodes/template_transform/template_transform_node.py ファイルの表示

variables = {} variables = {}
for variable_selector in self.node_data.variables: for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable_name] = value
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if value is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Variable {variable_name} not found in variable pool",
)
variables[variable_name] = value.to_object()
# Run code # Run code
try: try:
result = CodeExecutor.execute_workflow_code_template( result = CodeExecutor.execute_workflow_code_template(

+ 7
- 7
api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py ファイルの表示



if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables: for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get_any(selector)
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None: if variable is not None:
outputs = {"output": variable}
outputs = {"output": variable.to_object()}


inputs = {".".join(selector[1:]): variable}
inputs = {".".join(selector[1:]): variable.to_object()}
break break
else: else:
for group in self.node_data.advanced_settings.groups: for group in self.node_data.advanced_settings.groups:
for selector in group.variables: for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get_any(selector)
variable = self.graph_runtime_state.variable_pool.get(selector)


if variable is not None: if variable is not None:
outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable
outputs[group.group_name] = {"output": variable.to_object()}
inputs[".".join(selector[1:])] = variable.to_object()
break break


return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)


@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" """
Extract variable selector to variable mapping Extract variable selector to variable mapping

+ 4
- 0
api/tests/integration_tests/workflow/nodes/test_code.py ファイルの表示

} }


node = init_code_node(code_config) node = init_code_node(code_config)
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)


# execute node # execute node
result = node._run() result = node._run()
} }


node = init_code_node(code_config) node = init_code_node(code_config)
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)


# execute node # execute node
result = node._run() result = node._run()

読み込み中…
キャンセル
保存