浏览代码

refactor(api): Decouple `ParameterExtractorNode` from `LLMNode` (#20843)

- Extract methods used by `ParameterExtractorNode` from `LLMNode` into a separate file.
- Convert `ParameterExtractorNode` into a subclass of `BaseNode`.
- Refactor code referencing the extracted methods to ensure functionality and clarity.
- Fixes the issue that `ParameterExtractorNode` returns error when executed.
- Fix relevant test cases.

Closes #20840.
tags/1.4.2
QuantumGhost 4 个月前
父节点
当前提交
c439e82038
没有帐户链接到提交者的电子邮件

+ 3
- 3
api/core/plugin/backwards_invocation/model.py 查看文件

@@ -21,7 +21,7 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant


@@ -55,7 +55,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
LLMNode.deduct_llm_quota(
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
chunk.prompt_messages = []
@@ -64,7 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
return handle()
else:
if response.usage:
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(

+ 2
- 2
api/core/rag/retrieval/router/multi_dataset_react_route.py 查看文件

@@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.llm import llm_utils

PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""

@@ -165,7 +165,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)

# deduct quota
LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)

return text, usage


+ 156
- 0
api/core/workflow/nodes/llm/llm_utils.py 查看文件

@@ -0,0 +1,156 @@
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Optional, cast

from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from models import db
from models.model import Conversation
from models.provider import Provider, ProviderType

from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError


def fetch_model_config(
tenant_id: str, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")

model = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=node_data_model.provider,
model=node_data_model.name,
)

model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)

# check model
provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name, model_type=ModelType.LLM
)

if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()

# model config
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")

model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

return model, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=model.provider_model_bundle,
credentials=model.credentials,
parameters=node_data_model.completion_params,
stop=stop,
)


def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
variable = variable_pool.get(selector)
if variable is None:
return []
elif isinstance(variable, FileSegment):
return [variable.value]
elif isinstance(variable, ArrayFileSegment):
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")


def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
if not node_data_memory:
return None

# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value

with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory


def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration

if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return

system_configuration = provider_configuration.system_configuration

quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit

if quota_configuration.quota_limit == -1:
return

break

used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model)
else:
used_quota = 1

if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()

+ 22
- 142
api/core/workflow/nodes/llm/node.py 查看文件

@@ -3,16 +3,11 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast

import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -40,12 +35,10 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
ArraySegment,
FileSegment,
@@ -75,10 +68,8 @@ from core.workflow.utils.structured_output.entities import (
)
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType

from . import llm_utils
from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
@@ -88,7 +79,6 @@ from .entities import (
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError,
@@ -160,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]):
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
variable_pool = self.graph_runtime_state.variable_pool

try:
# init messages template
@@ -178,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]):

# fetch files
files = (
self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=self.node_data.vision.configs.variable_selector,
)
if self.node_data.vision.enabled
else []
)
@@ -200,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]):
model_instance, model_config = self._fetch_model_config(self.node_data.model)

# fetch memory
memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
)

query = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if not query and (
query_variable := self.graph_runtime_state.variable_pool.get(
(SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
)
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
query = query_variable.text

@@ -222,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
)

@@ -251,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = event.usage
finish_reason = event.finish_reason
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
structured_output = process_structured_output(result_text)
@@ -447,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]):

return inputs

def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is None:
return []
elif isinstance(variable, FileSegment):
return [variable.value]
elif isinstance(variable, ArrayFileSegment):
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")

def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
return
@@ -524,31 +509,10 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")

model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=node_data_model.provider,
model=node_data_model.name,
model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
)

model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)

# check model
provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name, model_type=ModelType.LLM
)

if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()

# model config
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
completion_params = model_config_with_cred.parameters

model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
@@ -556,47 +520,12 @@ class LLMNode(BaseNode[LLMNodeData]):

if self.node_data.structured_output_enabled:
if model_schema.support_structure_output:
node_data_model.completion_params = self._handle_native_json_schema(
node_data_model.completion_params, model_schema.parameter_rules
)
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
else:
# Set appropriate response format based on model capabilities
self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)

return model, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=model.provider_model_bundle,
credentials=model.credentials,
parameters=node_data_model.completion_params,
stop=stop,
)

def _fetch_memory(
self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
if not node_data_memory:
return None

# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value

with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

return memory
self._set_response_format(completion_params, model_schema.parameter_rules)
model_config_with_cred.parameters = completion_params
return model, model_config_with_cred

def _fetch_prompt_messages(
self,
@@ -810,55 +739,6 @@ class LLMNode(BaseNode[LLMNodeData]):
structured_output = parsed
return structured_output

@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration

if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return

system_configuration = provider_configuration.system_configuration

quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit

if quota_configuration.quota_limit == -1:
return

break

used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model)
else:
used_quota = 1

if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

+ 14
- 6
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py 查看文件

@@ -28,8 +28,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser

from .entities import ParameterExtractorNodeData
@@ -83,7 +84,7 @@ def extract_json(text):
return None


class ParameterExtractorNode(LLMNode):
class ParameterExtractorNode(BaseNode):
"""
Parameter Extractor Node.
"""
@@ -116,8 +117,11 @@ class ParameterExtractorNode(LLMNode):
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

variable_pool = self.graph_runtime_state.variable_pool

files = (
self._fetch_files(
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
@@ -137,7 +141,9 @@ class ParameterExtractorNode(LLMNode):
raise ModelSchemaNotFoundError("Model schema not found")

# fetch memory
memory = self._fetch_memory(
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
@@ -279,7 +285,7 @@ class ParameterExtractorNode(LLMNode):
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None

# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

if text is None:
text = ""
@@ -794,7 +800,9 @@ class ParameterExtractorNode(LLMNode):
Fetch model config.
"""
if not self._model_instance or not self._model_config:
self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
self._model_instance, self._model_config = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
)

return self._model_instance, self._model_config


+ 6
- 2
api/core/workflow/nodes/question_classifier/question_classifier_node.py 查看文件

@@ -18,6 +18,7 @@ from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode):
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
@@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode):
node_data.instruction = variable_pool.convert_template(node_data.instruction).text

files = (
self._fetch_files(
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled

+ 3
- 2
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py 查看文件

@@ -353,7 +353,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"


def test_chat_parameter_extractor_with_memory(setup_model_mock):
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"""
Test chat parameter extractor with memory.
"""
@@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
node._fetch_memory = get_mocked_fetch_memory("customized memory")
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()

result = node._run()

+ 20
- 14
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py 查看文件

@@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
@@ -170,7 +171,7 @@ def model_config():
)


def test_fetch_files_with_file_segment(llm_node):
def test_fetch_files_with_file_segment():
file = File(
id="1",
tenant_id="test",
@@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node):
related_id="1",
storage_key="",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], file)

result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == [file]


def test_fetch_files_with_array_file_segment(llm_node):
def test_fetch_files_with_array_file_segment():
files = [
File(
id="1",
@@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node):
storage_key="",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))

result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == files


def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
def test_fetch_files_with_none_segment():
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], NoneSegment())

result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []


def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
def test_fetch_files_with_array_any_segment():
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))

result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []


def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
def test_fetch_files_with_non_existent_variable():
variable_pool = VariablePool()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []



正在加载...
取消
保存