|
|
|
@@ -0,0 +1,125 @@ |
|
|
|
import pytest |
|
|
|
|
|
|
|
from core.app.entities.app_invoke_entities import InvokeFrom |
|
|
|
from core.file import File, FileTransferMethod, FileType |
|
|
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent |
|
|
|
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment |
|
|
|
from core.workflow.entities.variable_pool import VariablePool |
|
|
|
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState |
|
|
|
from core.workflow.nodes.answer import AnswerStreamGenerateRoute |
|
|
|
from core.workflow.nodes.end import EndStreamParam |
|
|
|
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions |
|
|
|
from core.workflow.nodes.llm.node import LLMNode |
|
|
|
from models.enums import UserFrom |
|
|
|
from models.workflow import WorkflowType |
|
|
|
|
|
|
|
|
|
|
|
class TestLLMNode: |
|
|
|
@pytest.fixture |
|
|
|
def llm_node(self): |
|
|
|
data = LLMNodeData( |
|
|
|
title="Test LLM", |
|
|
|
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), |
|
|
|
prompt_template=[], |
|
|
|
memory=None, |
|
|
|
context=ContextConfig(enabled=False), |
|
|
|
vision=VisionConfig( |
|
|
|
enabled=True, |
|
|
|
configs=VisionConfigOptions( |
|
|
|
variable_selector=["sys", "files"], |
|
|
|
detail=ImagePromptMessageContent.DETAIL.HIGH, |
|
|
|
), |
|
|
|
), |
|
|
|
) |
|
|
|
variable_pool = VariablePool( |
|
|
|
system_variables={}, |
|
|
|
user_inputs={}, |
|
|
|
) |
|
|
|
node = LLMNode( |
|
|
|
id="1", |
|
|
|
config={ |
|
|
|
"id": "1", |
|
|
|
"data": data.model_dump(), |
|
|
|
}, |
|
|
|
graph_init_params=GraphInitParams( |
|
|
|
tenant_id="1", |
|
|
|
app_id="1", |
|
|
|
workflow_type=WorkflowType.WORKFLOW, |
|
|
|
workflow_id="1", |
|
|
|
graph_config={}, |
|
|
|
user_id="1", |
|
|
|
user_from=UserFrom.ACCOUNT, |
|
|
|
invoke_from=InvokeFrom.SERVICE_API, |
|
|
|
call_depth=0, |
|
|
|
), |
|
|
|
graph=Graph( |
|
|
|
root_node_id="1", |
|
|
|
answer_stream_generate_routes=AnswerStreamGenerateRoute( |
|
|
|
answer_dependencies={}, |
|
|
|
answer_generate_route={}, |
|
|
|
), |
|
|
|
end_stream_param=EndStreamParam( |
|
|
|
end_dependencies={}, |
|
|
|
end_stream_variable_selector_mapping={}, |
|
|
|
), |
|
|
|
), |
|
|
|
graph_runtime_state=GraphRuntimeState( |
|
|
|
variable_pool=variable_pool, |
|
|
|
start_at=0, |
|
|
|
), |
|
|
|
) |
|
|
|
return node |
|
|
|
|
|
|
|
def test_fetch_files_with_file_segment(self, llm_node): |
|
|
|
file = File( |
|
|
|
id="1", |
|
|
|
tenant_id="test", |
|
|
|
type=FileType.IMAGE, |
|
|
|
filename="test.jpg", |
|
|
|
transfer_method=FileTransferMethod.LOCAL_FILE, |
|
|
|
related_id="1", |
|
|
|
) |
|
|
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) |
|
|
|
|
|
|
|
result = llm_node._fetch_files(selector=["sys", "files"]) |
|
|
|
assert result == [file] |
|
|
|
|
|
|
|
def test_fetch_files_with_array_file_segment(self, llm_node): |
|
|
|
files = [ |
|
|
|
File( |
|
|
|
id="1", |
|
|
|
tenant_id="test", |
|
|
|
type=FileType.IMAGE, |
|
|
|
filename="test1.jpg", |
|
|
|
transfer_method=FileTransferMethod.LOCAL_FILE, |
|
|
|
related_id="1", |
|
|
|
), |
|
|
|
File( |
|
|
|
id="2", |
|
|
|
tenant_id="test", |
|
|
|
type=FileType.IMAGE, |
|
|
|
filename="test2.jpg", |
|
|
|
transfer_method=FileTransferMethod.LOCAL_FILE, |
|
|
|
related_id="2", |
|
|
|
), |
|
|
|
] |
|
|
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) |
|
|
|
|
|
|
|
result = llm_node._fetch_files(selector=["sys", "files"]) |
|
|
|
assert result == files |
|
|
|
|
|
|
|
def test_fetch_files_with_none_segment(self, llm_node): |
|
|
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) |
|
|
|
|
|
|
|
result = llm_node._fetch_files(selector=["sys", "files"]) |
|
|
|
assert result == [] |
|
|
|
|
|
|
|
def test_fetch_files_with_array_any_segment(self, llm_node): |
|
|
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) |
|
|
|
|
|
|
|
result = llm_node._fetch_files(selector=["sys", "files"]) |
|
|
|
assert result == [] |
|
|
|
|
|
|
|
def test_fetch_files_with_non_existent_variable(self, llm_node): |
|
|
|
result = llm_node._fetch_files(selector=["sys", "files"]) |
|
|
|
assert result == [] |