Signed-off-by: -LAN- <laipz8200@outlook.com>tags/0.14.2
| @@ -1,15 +1,14 @@ | |||
| import base64 | |||
| from configs import dify_config | |||
| from core.file import file_repository | |||
| from core.helper import ssrf_proxy | |||
| from core.model_runtime.entities import ( | |||
| AudioPromptMessageContent, | |||
| DocumentPromptMessageContent, | |||
| ImagePromptMessageContent, | |||
| MultiModalPromptMessageContent, | |||
| VideoPromptMessageContent, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from . import helpers | |||
| @@ -41,7 +40,7 @@ def to_prompt_message_content( | |||
| /, | |||
| *, | |||
| image_detail_config: ImagePromptMessageContent.DETAIL | None = None, | |||
| ): | |||
| ) -> MultiModalPromptMessageContent: | |||
| if f.extension is None: | |||
| raise ValueError("Missing file extension") | |||
| if f.mime_type is None: | |||
| @@ -70,16 +69,13 @@ def to_prompt_message_content( | |||
| def download(f: File, /): | |||
| if f.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| tool_file = file_repository.get_tool_file(session=db.session(), file=f) | |||
| return _download_file_content(tool_file.file_key) | |||
| elif f.transfer_method == FileTransferMethod.LOCAL_FILE: | |||
| upload_file = file_repository.get_upload_file(session=db.session(), file=f) | |||
| return _download_file_content(upload_file.key) | |||
| # remote file | |||
| response = ssrf_proxy.get(f.remote_url, follow_redirects=True) | |||
| response.raise_for_status() | |||
| return response.content | |||
| if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): | |||
| return _download_file_content(f._storage_key) | |||
| elif f.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| response = ssrf_proxy.get(f.remote_url, follow_redirects=True) | |||
| response.raise_for_status() | |||
| return response.content | |||
| raise ValueError(f"unsupported transfer method: {f.transfer_method}") | |||
| def _download_file_content(path: str, /): | |||
| @@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /): | |||
| response.raise_for_status() | |||
| data = response.content | |||
| case FileTransferMethod.LOCAL_FILE: | |||
| upload_file = file_repository.get_upload_file(session=db.session(), file=f) | |||
| data = _download_file_content(upload_file.key) | |||
| data = _download_file_content(f._storage_key) | |||
| case FileTransferMethod.TOOL_FILE: | |||
| tool_file = file_repository.get_tool_file(session=db.session(), file=f) | |||
| data = _download_file_content(tool_file.file_key) | |||
| data = _download_file_content(f._storage_key) | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return encoded_string | |||
| @@ -1,32 +0,0 @@ | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from models import ToolFile, UploadFile | |||
| from .models import File | |||
| def get_upload_file(*, session: Session, file: File): | |||
| if file.related_id is None: | |||
| raise ValueError("Missing file related_id") | |||
| stmt = select(UploadFile).filter( | |||
| UploadFile.id == file.related_id, | |||
| UploadFile.tenant_id == file.tenant_id, | |||
| ) | |||
| record = session.scalar(stmt) | |||
| if not record: | |||
| raise ValueError(f"upload file {file.related_id} not found") | |||
| return record | |||
| def get_tool_file(*, session: Session, file: File): | |||
| if file.related_id is None: | |||
| raise ValueError("Missing file related_id") | |||
| stmt = select(ToolFile).filter( | |||
| ToolFile.id == file.related_id, | |||
| ToolFile.tenant_id == file.tenant_id, | |||
| ) | |||
| record = session.scalar(stmt) | |||
| if not record: | |||
| raise ValueError(f"tool file {file.related_id} not found") | |||
| return record | |||
| @@ -47,6 +47,38 @@ class File(BaseModel): | |||
| mime_type: Optional[str] = None | |||
| size: int = -1 | |||
| # Those properties are private, should not be exposed to the outside. | |||
| _storage_key: str | |||
| def __init__( | |||
| self, | |||
| *, | |||
| id: Optional[str] = None, | |||
| tenant_id: str, | |||
| type: FileType, | |||
| transfer_method: FileTransferMethod, | |||
| remote_url: Optional[str] = None, | |||
| related_id: Optional[str] = None, | |||
| filename: Optional[str] = None, | |||
| extension: Optional[str] = None, | |||
| mime_type: Optional[str] = None, | |||
| size: int = -1, | |||
| storage_key: str, | |||
| ): | |||
| super().__init__( | |||
| id=id, | |||
| tenant_id=tenant_id, | |||
| type=type, | |||
| transfer_method=transfer_method, | |||
| remote_url=remote_url, | |||
| related_id=related_id, | |||
| filename=filename, | |||
| extension=extension, | |||
| mime_type=mime_type, | |||
| size=size, | |||
| ) | |||
| self._storage_key = storage_key | |||
| def to_dict(self) -> Mapping[str, str | int | None]: | |||
| data = self.model_dump(mode="json") | |||
| return { | |||
| @@ -4,6 +4,7 @@ from .message_entities import ( | |||
| AudioPromptMessageContent, | |||
| DocumentPromptMessageContent, | |||
| ImagePromptMessageContent, | |||
| MultiModalPromptMessageContent, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| PromptMessageContentType, | |||
| @@ -27,6 +28,7 @@ __all__ = [ | |||
| "LLMResultChunkDelta", | |||
| "LLMUsage", | |||
| "ModelPropertyKey", | |||
| "MultiModalPromptMessageContent", | |||
| "PromptMessage", | |||
| "PromptMessage", | |||
| "PromptMessageContent", | |||
| @@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent): | |||
| """ | |||
| type: PromptMessageContentType | |||
| format: str = Field(..., description="the format of multi-modal file") | |||
| base64_data: str = Field("", description="the base64 data of multi-modal file") | |||
| url: str = Field("", description="the url of multi-modal file") | |||
| mime_type: str = Field(..., description="the mime type of multi-modal file") | |||
| format: str = Field(default=..., description="the format of multi-modal file") | |||
| base64_data: str = Field(default="", description="the base64 data of multi-modal file") | |||
| url: str = Field(default="", description="the url of multi-modal file") | |||
| mime_type: str = Field(default=..., description="the mime type of multi-modal file") | |||
| @computed_field(return_type=str) | |||
| @property | |||
| @@ -50,6 +50,7 @@ class PromptConfig(BaseModel): | |||
| class LLMNodeChatModelMessage(ChatModelMessage): | |||
| text: str = "" | |||
| jinja2_text: Optional[str] = None | |||
| @@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| query = query_variable.text | |||
| prompt_messages, stop = self._fetch_prompt_messages( | |||
| user_query=query, | |||
| user_files=files, | |||
| sys_query=query, | |||
| sys_files=files, | |||
| context=context, | |||
| memory=memory, | |||
| model_config=model_config, | |||
| @@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| def _fetch_prompt_messages( | |||
| self, | |||
| *, | |||
| user_query: str | None = None, | |||
| user_files: Sequence["File"], | |||
| sys_query: str | None = None, | |||
| sys_files: Sequence["File"], | |||
| context: str | None = None, | |||
| memory: TokenBufferMemory | None = None, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| @@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| if isinstance(prompt_template, list): | |||
| # For chat model | |||
| prompt_messages.extend( | |||
| _handle_list_messages( | |||
| self._handle_list_messages( | |||
| messages=prompt_template, | |||
| context=context, | |||
| jinja2_variables=jinja2_variables, | |||
| @@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| prompt_messages.extend(memory_messages) | |||
| # Add current query to the prompt messages | |||
| if user_query: | |||
| if sys_query: | |||
| message = LLMNodeChatModelMessage( | |||
| text=user_query, | |||
| text=sys_query, | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ) | |||
| prompt_messages.extend( | |||
| _handle_list_messages( | |||
| self._handle_list_messages( | |||
| messages=[message], | |||
| context="", | |||
| jinja2_variables=[], | |||
| @@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| raise ValueError("Invalid prompt content type") | |||
| # Add current query to the prompt message | |||
| if user_query: | |||
| if sys_query: | |||
| if prompt_content_type == str: | |||
| prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) | |||
| prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query) | |||
| prompt_messages[0].content = prompt_content | |||
| elif prompt_content_type == list: | |||
| for content_item in prompt_content: | |||
| if content_item.type == PromptMessageContentType.TEXT: | |||
| content_item.data = user_query + "\n" + content_item.data | |||
| content_item.data = sys_query + "\n" + content_item.data | |||
| else: | |||
| raise ValueError("Invalid prompt content type") | |||
| else: | |||
| raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) | |||
| if vision_enabled and user_files: | |||
| # The sys_files will be deprecated later | |||
| if vision_enabled and sys_files: | |||
| file_prompts = [] | |||
| for file in user_files: | |||
| for file in sys_files: | |||
| file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) | |||
| file_prompts.append(file_prompt) | |||
| # If last prompt is a user prompt, add files into its contents, | |||
| # otherwise append a new user prompt | |||
| if ( | |||
| len(prompt_messages) > 0 | |||
| and isinstance(prompt_messages[-1], UserPromptMessage) | |||
| @@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| else: | |||
| prompt_messages.append(UserPromptMessage(content=file_prompts)) | |||
| # Filter prompt messages | |||
| # Remove empty messages and filter unsupported content | |||
| filtered_prompt_messages = [] | |||
| for prompt_message in prompt_messages: | |||
| if isinstance(prompt_message.content, list): | |||
| @@ -846,6 +849,58 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| }, | |||
| } | |||
| def _handle_list_messages( | |||
| self, | |||
| *, | |||
| messages: Sequence[LLMNodeChatModelMessage], | |||
| context: Optional[str], | |||
| jinja2_variables: Sequence[VariableSelector], | |||
| variable_pool: VariablePool, | |||
| vision_detail_config: ImagePromptMessageContent.DETAIL, | |||
| ) -> Sequence[PromptMessage]: | |||
| prompt_messages: list[PromptMessage] = [] | |||
| for message in messages: | |||
| contents: list[PromptMessageContent] = [] | |||
| if message.edition_type == "jinja2": | |||
| result_text = _render_jinja2_message( | |||
| template=message.jinja2_text or "", | |||
| jinjia2_variables=jinja2_variables, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| contents.append(TextPromptMessageContent(data=result_text)) | |||
| else: | |||
| # Get segment group from basic message | |||
| if context: | |||
| template = message.text.replace("{#context#}", context) | |||
| else: | |||
| template = message.text | |||
| segment_group = variable_pool.convert_template(template) | |||
| # Process segments for images | |||
| for segment in segment_group.value: | |||
| if isinstance(segment, ArrayFileSegment): | |||
| for file in segment.value: | |||
| if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: | |||
| file_content = file_manager.to_prompt_message_content( | |||
| file, image_detail_config=vision_detail_config | |||
| ) | |||
| contents.append(file_content) | |||
| elif isinstance(segment, FileSegment): | |||
| file = segment.value | |||
| if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: | |||
| file_content = file_manager.to_prompt_message_content( | |||
| file, image_detail_config=vision_detail_config | |||
| ) | |||
| contents.append(file_content) | |||
| else: | |||
| plain_text = segment.markdown.strip() | |||
| if plain_text: | |||
| contents.append(TextPromptMessageContent(data=plain_text)) | |||
| prompt_message = _combine_message_content_with_role(contents=contents, role=message.role) | |||
| prompt_messages.append(prompt_message) | |||
| return prompt_messages | |||
| def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): | |||
| match role: | |||
| @@ -880,68 +935,6 @@ def _render_jinja2_message( | |||
| return result_text | |||
| def _handle_list_messages( | |||
| *, | |||
| messages: Sequence[LLMNodeChatModelMessage], | |||
| context: Optional[str], | |||
| jinja2_variables: Sequence[VariableSelector], | |||
| variable_pool: VariablePool, | |||
| vision_detail_config: ImagePromptMessageContent.DETAIL, | |||
| ) -> Sequence[PromptMessage]: | |||
| prompt_messages = [] | |||
| for message in messages: | |||
| if message.edition_type == "jinja2": | |||
| result_text = _render_jinja2_message( | |||
| template=message.jinja2_text or "", | |||
| jinjia2_variables=jinja2_variables, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| prompt_message = _combine_message_content_with_role( | |||
| contents=[TextPromptMessageContent(data=result_text)], role=message.role | |||
| ) | |||
| prompt_messages.append(prompt_message) | |||
| else: | |||
| # Get segment group from basic message | |||
| if context: | |||
| template = message.text.replace("{#context#}", context) | |||
| else: | |||
| template = message.text | |||
| segment_group = variable_pool.convert_template(template) | |||
| # Process segments for images | |||
| file_contents = [] | |||
| for segment in segment_group.value: | |||
| if isinstance(segment, ArrayFileSegment): | |||
| for file in segment.value: | |||
| if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: | |||
| file_content = file_manager.to_prompt_message_content( | |||
| file, image_detail_config=vision_detail_config | |||
| ) | |||
| file_contents.append(file_content) | |||
| if isinstance(segment, FileSegment): | |||
| file = segment.value | |||
| if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: | |||
| file_content = file_manager.to_prompt_message_content( | |||
| file, image_detail_config=vision_detail_config | |||
| ) | |||
| file_contents.append(file_content) | |||
| # Create message with text from all segments | |||
| plain_text = segment_group.text | |||
| if plain_text: | |||
| prompt_message = _combine_message_content_with_role( | |||
| contents=[TextPromptMessageContent(data=plain_text)], role=message.role | |||
| ) | |||
| prompt_messages.append(prompt_message) | |||
| if file_contents: | |||
| # Create message with image contents | |||
| prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) | |||
| prompt_messages.append(prompt_message) | |||
| return prompt_messages | |||
| def _calculate_rest_token( | |||
| *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity | |||
| ) -> int: | |||
| @@ -86,10 +86,10 @@ class QuestionClassifierNode(LLMNode): | |||
| ) | |||
| prompt_messages, stop = self._fetch_prompt_messages( | |||
| prompt_template=prompt_template, | |||
| user_query=query, | |||
| sys_query=query, | |||
| memory=memory, | |||
| model_config=model_config, | |||
| user_files=files, | |||
| sys_files=files, | |||
| vision_enabled=node_data.vision.enabled, | |||
| vision_detail=node_data.vision.configs.detail, | |||
| variable_pool=variable_pool, | |||
| @@ -139,6 +139,7 @@ def _build_from_local_file( | |||
| remote_url=row.source_url, | |||
| related_id=mapping.get("upload_file_id"), | |||
| size=row.size, | |||
| storage_key=row.key, | |||
| ) | |||
| @@ -168,6 +169,7 @@ def _build_from_remote_url( | |||
| mime_type=mime_type, | |||
| extension=extension, | |||
| size=file_size, | |||
| storage_key="", | |||
| ) | |||
| @@ -220,6 +222,7 @@ def _build_from_tool_file( | |||
| extension=extension, | |||
| mime_type=tool_file.mimetype, | |||
| size=tool_file.size, | |||
| storage_key=tool_file.file_key, | |||
| ) | |||
| @@ -560,13 +560,29 @@ class Conversation(db.Model): | |||
| @property | |||
| def inputs(self): | |||
| inputs = self._inputs.copy() | |||
| # Convert file mapping to File object | |||
| for key, value in inputs.items(): | |||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | |||
| from factories import file_factory | |||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| inputs[key] = File.model_validate(value) | |||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value["tool_file_id"] = value["related_id"] | |||
| elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: | |||
| value["upload_file_id"] = value["related_id"] | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||
| elif isinstance(value, list) and all( | |||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||
| ): | |||
| inputs[key] = [File.model_validate(item) for item in value] | |||
| inputs[key] = [] | |||
| for item in value: | |||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item["tool_file_id"] = item["related_id"] | |||
| elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: | |||
| item["upload_file_id"] = item["related_id"] | |||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||
| return inputs | |||
| @inputs.setter | |||
| @@ -758,12 +774,25 @@ class Message(db.Model): | |||
| def inputs(self): | |||
| inputs = self._inputs.copy() | |||
| for key, value in inputs.items(): | |||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | |||
| from factories import file_factory | |||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| inputs[key] = File.model_validate(value) | |||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value["tool_file_id"] = value["related_id"] | |||
| elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: | |||
| value["upload_file_id"] = value["related_id"] | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||
| elif isinstance(value, list) and all( | |||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||
| ): | |||
| inputs[key] = [File.model_validate(item) for item in value] | |||
| inputs[key] = [] | |||
| for item in value: | |||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item["tool_file_id"] = item["related_id"] | |||
| elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: | |||
| item["upload_file_id"] = item["related_id"] | |||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||
| return inputs | |||
| @inputs.setter | |||
| @@ -136,6 +136,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url="https://example.com/image1.jpg", | |||
| storage_key="", | |||
| ) | |||
| ] | |||
| @@ -1,34 +1,9 @@ | |||
| import json | |||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig | |||
| from core.file import File, FileTransferMethod, FileType, FileUploadConfig | |||
| from models.workflow import Workflow | |||
| def test_file_loads_and_dumps(): | |||
| file = File( | |||
| id="file1", | |||
| tenant_id="tenant1", | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url="https://example.com/image1.jpg", | |||
| ) | |||
| file_dict = file.model_dump() | |||
| assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY | |||
| assert file_dict["type"] == file.type.value | |||
| assert isinstance(file_dict["type"], str) | |||
| assert file_dict["transfer_method"] == file.transfer_method.value | |||
| assert isinstance(file_dict["transfer_method"], str) | |||
| assert "_extra_config" not in file_dict | |||
| file_obj = File.model_validate(file_dict) | |||
| assert file_obj.id == file.id | |||
| assert file_obj.tenant_id == file.tenant_id | |||
| assert file_obj.type == file.type | |||
| assert file_obj.transfer_method == file.transfer_method | |||
| assert file_obj.remote_url == file.remote_url | |||
| def test_file_to_dict(): | |||
| file = File( | |||
| id="file1", | |||
| @@ -36,10 +11,11 @@ def test_file_to_dict(): | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url="https://example.com/image1.jpg", | |||
| storage_key="storage_key", | |||
| ) | |||
| file_dict = file.to_dict() | |||
| assert "_extra_config" not in file_dict | |||
| assert "_storage_key" not in file_dict | |||
| assert "url" in file_dict | |||
| @@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch): | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1111", | |||
| storage_key="", | |||
| ), | |||
| ), | |||
| ) | |||
| @@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch): | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1111", | |||
| storage_key="", | |||
| ), | |||
| ), | |||
| ) | |||
| @@ -21,7 +21,8 @@ from core.model_runtime.entities.message_entities import ( | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.prompt.entities.advanced_prompt_entities import MemoryConfig | |||
| from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment | |||
| from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment | |||
| from core.workflow.entities.variable_entities import VariableSelector | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | |||
| from core.workflow.nodes.answer import AnswerStreamGenerateRoute | |||
| @@ -157,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node): | |||
| filename="test.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1", | |||
| storage_key="", | |||
| ) | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) | |||
| @@ -173,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node): | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1", | |||
| storage_key="", | |||
| ), | |||
| File( | |||
| id="2", | |||
| @@ -181,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node): | |||
| filename="test2.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="2", | |||
| storage_key="", | |||
| ), | |||
| ] | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) | |||
| @@ -224,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url=fake_remote_url, | |||
| storage_key="", | |||
| ) | |||
| ] | |||
| fake_query = faker.sentence() | |||
| prompt_messages, _ = llm_node._fetch_prompt_messages( | |||
| user_query=fake_query, | |||
| user_files=files, | |||
| sys_query=fake_query, | |||
| sys_files=files, | |||
| context=None, | |||
| memory=None, | |||
| model_config=model_config, | |||
| @@ -283,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| test_scenarios = [ | |||
| LLMNodeTestScenario( | |||
| description="No files", | |||
| user_query=fake_query, | |||
| user_files=[], | |||
| sys_query=fake_query, | |||
| sys_files=[], | |||
| features=[], | |||
| vision_enabled=False, | |||
| vision_detail=None, | |||
| @@ -318,8 +323,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| ), | |||
| LLMNodeTestScenario( | |||
| description="User files", | |||
| user_query=fake_query, | |||
| user_files=[ | |||
| sys_query=fake_query, | |||
| sys_files=[ | |||
| File( | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| @@ -328,6 +333,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| remote_url=fake_remote_url, | |||
| extension=".jpg", | |||
| mime_type="image/jpg", | |||
| storage_key="", | |||
| ) | |||
| ], | |||
| vision_enabled=True, | |||
| @@ -370,8 +376,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| ), | |||
| LLMNodeTestScenario( | |||
| description="Prompt template with variable selector of File", | |||
| user_query=fake_query, | |||
| user_files=[], | |||
| sys_query=fake_query, | |||
| sys_files=[], | |||
| vision_enabled=False, | |||
| vision_detail=fake_vision_detail, | |||
| features=[ModelFeature.VISION], | |||
| @@ -403,6 +409,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| remote_url=fake_remote_url, | |||
| extension=".jpg", | |||
| mime_type="image/jpg", | |||
| storage_key="", | |||
| ) | |||
| }, | |||
| ), | |||
| @@ -417,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| # Call the method under test | |||
| prompt_messages, _ = llm_node._fetch_prompt_messages( | |||
| user_query=scenario.user_query, | |||
| user_files=scenario.user_files, | |||
| sys_query=scenario.sys_query, | |||
| sys_files=scenario.sys_files, | |||
| context=fake_context, | |||
| memory=memory, | |||
| model_config=model_config, | |||
| @@ -435,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| assert ( | |||
| prompt_messages == scenario.expected_messages | |||
| ), f"Message content mismatch in scenario: {scenario.description}" | |||
| def test_handle_list_messages_basic(llm_node): | |||
| messages = [ | |||
| LLMNodeChatModelMessage( | |||
| text="Hello, {#context#}", | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ) | |||
| ] | |||
| context = "world" | |||
| jinja2_variables = [] | |||
| variable_pool = llm_node.graph_runtime_state.variable_pool | |||
| vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH | |||
| result = llm_node._handle_list_messages( | |||
| messages=messages, | |||
| context=context, | |||
| jinja2_variables=jinja2_variables, | |||
| variable_pool=variable_pool, | |||
| vision_detail_config=vision_detail_config, | |||
| ) | |||
| assert len(result) == 1 | |||
| assert isinstance(result[0], UserPromptMessage) | |||
| assert result[0].content == [TextPromptMessageContent(data="Hello, world")] | |||
| @@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel): | |||
| """Test scenario for LLM node testing.""" | |||
| description: str = Field(..., description="Description of the test scenario") | |||
| user_query: str = Field(..., description="User query input") | |||
| user_files: Sequence[File] = Field(default_factory=list, description="List of user files") | |||
| sys_query: str = Field(..., description="User query input") | |||
| sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") | |||
| vision_enabled: bool = Field(default=False, description="Whether vision is enabled") | |||
| vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") | |||
| features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") | |||
| @@ -248,6 +248,7 @@ def test_array_file_contains_file_name(): | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1", | |||
| filename="ab", | |||
| storage_key="", | |||
| ), | |||
| ], | |||
| ) | |||
| @@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node): | |||
| tenant_id="tenant1", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="related1", | |||
| storage_key="", | |||
| ), | |||
| File( | |||
| filename="document1.pdf", | |||
| @@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node): | |||
| tenant_id="tenant1", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="related2", | |||
| storage_key="", | |||
| ), | |||
| File( | |||
| filename="image2.png", | |||
| @@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node): | |||
| tenant_id="tenant1", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="related3", | |||
| storage_key="", | |||
| ), | |||
| File( | |||
| filename="audio1.mp3", | |||
| @@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node): | |||
| tenant_id="tenant1", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="related4", | |||
| storage_key="", | |||
| ), | |||
| ] | |||
| variable = ArrayFileSegment(value=files) | |||
| @@ -130,6 +134,7 @@ def test_get_file_extract_string_func(): | |||
| mime_type="text/plain", | |||
| remote_url="https://example.com/test_file.txt", | |||
| related_id="test_related_id", | |||
| storage_key="", | |||
| ) | |||
| # Test each case | |||
| @@ -150,6 +155,7 @@ def test_get_file_extract_string_func(): | |||
| mime_type=None, | |||
| remote_url=None, | |||
| related_id="test_related_id", | |||
| storage_key="", | |||
| ) | |||
| assert _get_file_extract_string_func(key="name")(empty_file) == "" | |||
| @@ -19,6 +19,7 @@ def file(): | |||
| related_id="test_related_id", | |||
| remote_url="test_url", | |||
| filename="test_file.txt", | |||
| storage_key="", | |||
| ) | |||