Co-authored-by: Joel <iamjoel007@gmail.com>tags/0.12.0
| @@ -27,7 +27,6 @@ class DifyConfig( | |||
| # read from dotenv format config file | |||
| env_file=".env", | |||
| env_file_encoding="utf-8", | |||
| frozen=True, | |||
| # ignore extra attributes | |||
| extra="ignore", | |||
| ) | |||
| @@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager | |||
| class ModelConfigConverter: | |||
| @classmethod | |||
| def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: | |||
| def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity: | |||
| """ | |||
| Convert app model config dict to entity. | |||
| :param app_config: app config | |||
| @@ -38,27 +38,23 @@ class ModelConfigConverter: | |||
| ) | |||
| if model_credentials is None: | |||
| if not skip_check: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||
| else: | |||
| model_credentials = {} | |||
| if not skip_check: | |||
| # check model | |||
| provider_model = provider_model_bundle.configuration.get_provider_model( | |||
| model=model_config.model, model_type=ModelType.LLM | |||
| ) | |||
| if provider_model is None: | |||
| model_name = model_config.model | |||
| raise ValueError(f"Model {model_name} not exist.") | |||
| if provider_model.status == ModelStatus.NO_CONFIGURE: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||
| elif provider_model.status == ModelStatus.NO_PERMISSION: | |||
| raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") | |||
| elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: | |||
| raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") | |||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||
| # check model | |||
| provider_model = provider_model_bundle.configuration.get_provider_model( | |||
| model=model_config.model, model_type=ModelType.LLM | |||
| ) | |||
| if provider_model is None: | |||
| model_name = model_config.model | |||
| raise ValueError(f"Model {model_name} not exist.") | |||
| if provider_model.status == ModelStatus.NO_CONFIGURE: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||
| elif provider_model.status == ModelStatus.NO_PERMISSION: | |||
| raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") | |||
| elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: | |||
| raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") | |||
| # model config | |||
| completion_params = model_config.parameters | |||
| @@ -76,7 +72,7 @@ class ModelConfigConverter: | |||
| model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) | |||
| if not skip_check and not model_schema: | |||
| if not model_schema: | |||
| raise ValueError(f"Model {model_name} not exist.") | |||
| return ModelConfigWithCredentialsEntity( | |||
| @@ -217,9 +217,12 @@ class WorkflowCycleManage: | |||
| ).total_seconds() | |||
| db.session.commit() | |||
| db.session.refresh(workflow_run) | |||
| db.session.close() | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| session.add(workflow_run) | |||
| session.refresh(workflow_run) | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| @@ -3,7 +3,12 @@ import base64 | |||
| from configs import dify_config | |||
| from core.file import file_repository | |||
| from core.helper import ssrf_proxy | |||
| from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent | |||
| from core.model_runtime.entities import ( | |||
| AudioPromptMessageContent, | |||
| DocumentPromptMessageContent, | |||
| ImagePromptMessageContent, | |||
| VideoPromptMessageContent, | |||
| ) | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| @@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute): | |||
| return file.remote_url | |||
| case FileAttribute.EXTENSION: | |||
| return file.extension | |||
| case _: | |||
| raise ValueError(f"Invalid file attribute: {attr}") | |||
| def to_prompt_message_content( | |||
| f: File, | |||
| /, | |||
| *, | |||
| image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, | |||
| image_detail_config: ImagePromptMessageContent.DETAIL | None = None, | |||
| ): | |||
| """ | |||
| Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object. | |||
| This function takes a File object and converts it to an appropriate PromptMessageContent | |||
| object, which can be used as a prompt for image or audio-based AI models. | |||
| Args: | |||
| f (File): The File object to convert. | |||
| detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts. | |||
| If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW. | |||
| Returns: | |||
| Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level | |||
| Raises: | |||
| ValueError: If the file type is not supported or if required data is missing. | |||
| """ | |||
| match f.type: | |||
| case FileType.IMAGE: | |||
| image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW | |||
| if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": | |||
| data = _to_url(f) | |||
| else: | |||
| @@ -65,7 +52,7 @@ def to_prompt_message_content( | |||
| return ImagePromptMessageContent(data=data, detail=image_detail_config) | |||
| case FileType.AUDIO: | |||
| encoded_string = _file_to_encoded_string(f) | |||
| encoded_string = _get_encoded_string(f) | |||
| if f.extension is None: | |||
| raise ValueError("Missing file extension") | |||
| return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) | |||
| @@ -74,9 +61,20 @@ def to_prompt_message_content( | |||
| data = _to_url(f) | |||
| else: | |||
| data = _to_base64_data_string(f) | |||
| if f.extension is None: | |||
| raise ValueError("Missing file extension") | |||
| return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) | |||
| case FileType.DOCUMENT: | |||
| data = _get_encoded_string(f) | |||
| if f.mime_type is None: | |||
| raise ValueError("Missing file mime_type") | |||
| return DocumentPromptMessageContent( | |||
| encode_format="base64", | |||
| mime_type=f.mime_type, | |||
| data=data, | |||
| ) | |||
| case _: | |||
| raise ValueError("file type f.type is not supported") | |||
| raise ValueError(f"file type {f.type} is not supported") | |||
| def download(f: File, /): | |||
| @@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /): | |||
| case FileTransferMethod.REMOTE_URL: | |||
| response = ssrf_proxy.get(f.remote_url, follow_redirects=True) | |||
| response.raise_for_status() | |||
| content = response.content | |||
| encoded_string = base64.b64encode(content).decode("utf-8") | |||
| return encoded_string | |||
| data = response.content | |||
| case FileTransferMethod.LOCAL_FILE: | |||
| upload_file = file_repository.get_upload_file(session=db.session(), file=f) | |||
| data = _download_file_content(upload_file.key) | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return encoded_string | |||
| case FileTransferMethod.TOOL_FILE: | |||
| tool_file = file_repository.get_tool_file(session=db.session(), file=f) | |||
| data = _download_file_content(tool_file.file_key) | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return encoded_string | |||
| case _: | |||
| raise ValueError(f"Unsupported transfer method: {f.transfer_method}") | |||
| encoded_string = base64.b64encode(data).decode("utf-8") | |||
| return encoded_string | |||
| def _to_base64_data_string(f: File, /): | |||
| @@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /): | |||
| return f"data:{f.mime_type};base64,{encoded_string}" | |||
| def _file_to_encoded_string(f: File, /): | |||
| match f.type: | |||
| case FileType.IMAGE: | |||
| return _to_base64_data_string(f) | |||
| case FileType.VIDEO: | |||
| return _to_base64_data_string(f) | |||
| case FileType.AUDIO: | |||
| return _get_encoded_string(f) | |||
| case _: | |||
| raise ValueError(f"file type {f.type} is not supported") | |||
| def _to_url(f: File, /): | |||
| if f.transfer_method == FileTransferMethod.REMOTE_URL: | |||
| if f.remote_url is None: | |||
| @@ -1,3 +1,4 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Optional | |||
| from core.app.app_config.features.file_upload.manager import FileUploadConfigManager | |||
| @@ -27,7 +28,7 @@ class TokenBufferMemory: | |||
| def get_history_prompt_messages( | |||
| self, max_token_limit: int = 2000, message_limit: Optional[int] = None | |||
| ) -> list[PromptMessage]: | |||
| ) -> Sequence[PromptMessage]: | |||
| """ | |||
| Get history prompt messages. | |||
| :param max_token_limit: max token limit | |||
| @@ -100,10 +100,10 @@ class ModelInstance: | |||
| def invoke_llm( | |||
| self, | |||
| prompt_messages: list[PromptMessage], | |||
| prompt_messages: Sequence[PromptMessage], | |||
| model_parameters: Optional[dict] = None, | |||
| tools: Sequence[PromptMessageTool] | None = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -1,4 +1,5 @@ | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Sequence | |||
| from typing import Optional | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| @@ -31,7 +32,7 @@ class Callback(ABC): | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| @@ -60,7 +61,7 @@ class Callback(ABC): | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ): | |||
| @@ -90,7 +91,7 @@ class Callback(ABC): | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| @@ -120,7 +121,7 @@ class Callback(ABC): | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> None: | |||
| @@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa | |||
| from .message_entities import ( | |||
| AssistantPromptMessage, | |||
| AudioPromptMessageContent, | |||
| DocumentPromptMessageContent, | |||
| ImagePromptMessageContent, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| @@ -37,4 +38,5 @@ __all__ = [ | |||
| "LLMResultChunk", | |||
| "LLMResultChunkDelta", | |||
| "AudioPromptMessageContent", | |||
| "DocumentPromptMessageContent", | |||
| ] | |||
| @@ -1,6 +1,7 @@ | |||
| from abc import ABC | |||
| from collections.abc import Sequence | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from typing import Literal, Optional | |||
| from pydantic import BaseModel, Field, field_validator | |||
| @@ -57,6 +58,7 @@ class PromptMessageContentType(Enum): | |||
| IMAGE = "image" | |||
| AUDIO = "audio" | |||
| VIDEO = "video" | |||
| DOCUMENT = "document" | |||
| class PromptMessageContent(BaseModel): | |||
| @@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent): | |||
| detail: DETAIL = DETAIL.LOW | |||
| class DocumentPromptMessageContent(PromptMessageContent): | |||
| type: PromptMessageContentType = PromptMessageContentType.DOCUMENT | |||
| encode_format: Literal["base64"] | |||
| mime_type: str | |||
| data: str | |||
| class PromptMessage(ABC, BaseModel): | |||
| """ | |||
| Model class for prompt message. | |||
| """ | |||
| role: PromptMessageRole | |||
| content: Optional[str | list[PromptMessageContent]] = None | |||
| content: Optional[str | Sequence[PromptMessageContent]] = None | |||
| name: Optional[str] = None | |||
| def is_empty(self) -> bool: | |||
| @@ -87,6 +87,9 @@ class ModelFeature(Enum): | |||
| AGENT_THOUGHT = "agent-thought" | |||
| VISION = "vision" | |||
| STREAM_TOOL_CALL = "stream-tool-call" | |||
| DOCUMENT = "document" | |||
| VIDEO = "video" | |||
| AUDIO = "audio" | |||
| class DefaultParameterName(str, Enum): | |||
| @@ -2,7 +2,7 @@ import logging | |||
| import re | |||
| import time | |||
| from abc import abstractmethod | |||
| from collections.abc import Generator, Mapping | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Optional, Union | |||
| from pydantic import ConfigDict | |||
| @@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel): | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: Optional[dict] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel): | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -212,7 +212,7 @@ if you are not sure about the structure. | |||
| ) | |||
| model_parameters.pop("response_format") | |||
| stop = stop or [] | |||
| stop = list(stop) if stop is not None else [] | |||
| stop.extend(["\n```", "```\n"]) | |||
| block_prompts = block_prompts.replace("{{block}}", code_block) | |||
| @@ -408,7 +408,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -479,7 +479,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> Union[LLMResult, Generator]: | |||
| @@ -601,7 +601,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -647,7 +647,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -694,7 +694,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -742,7 +742,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| callbacks: Optional[list[Callback]] = None, | |||
| @@ -7,6 +7,7 @@ features: | |||
| - vision | |||
| - tool-call | |||
| - stream-tool-call | |||
| - document | |||
| model_properties: | |||
| mode: chat | |||
| context_size: 200000 | |||
| @@ -7,6 +7,7 @@ features: | |||
| - vision | |||
| - tool-call | |||
| - stream-tool-call | |||
| - document | |||
| model_properties: | |||
| mode: chat | |||
| context_size: 200000 | |||
| @@ -1,7 +1,7 @@ | |||
| import base64 | |||
| import io | |||
| import json | |||
| from collections.abc import Generator | |||
| from collections.abc import Generator, Sequence | |||
| from typing import Optional, Union, cast | |||
| import anthropic | |||
| @@ -21,9 +21,9 @@ from httpx import Timeout | |||
| from PIL import Image | |||
| from core.model_runtime.callbacks.base_callback import Callback | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import ( | |||
| from core.model_runtime.entities import ( | |||
| AssistantPromptMessage, | |||
| DocumentPromptMessageContent, | |||
| ImagePromptMessageContent, | |||
| PromptMessage, | |||
| PromptMessageContentType, | |||
| @@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import ( | |||
| ToolPromptMessage, | |||
| UserPromptMessage, | |||
| ) | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| InvokeBadRequestError, | |||
| @@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): | |||
| self, | |||
| model: str, | |||
| credentials: dict, | |||
| prompt_messages: list[PromptMessage], | |||
| prompt_messages: Sequence[PromptMessage], | |||
| model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, | |||
| stop: Optional[Sequence[str]] = None, | |||
| stream: bool = True, | |||
| user: Optional[str] = None, | |||
| ) -> Union[LLMResult, Generator]: | |||
| @@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): | |||
| # Add the new header for claude-3-5-sonnet-20240620 model | |||
| extra_headers = {} | |||
| if model == "claude-3-5-sonnet-20240620": | |||
| if model_parameters.get("max_tokens") > 4096: | |||
| if model_parameters.get("max_tokens", 0) > 4096: | |||
| extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" | |||
| if any( | |||
| isinstance(content, DocumentPromptMessageContent) | |||
| for prompt_message in prompt_messages | |||
| if isinstance(prompt_message.content, list) | |||
| for content in prompt_message.content | |||
| ): | |||
| extra_headers["anthropic-beta"] = "pdfs-2024-09-25" | |||
| if tools: | |||
| extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] | |||
| response = client.beta.tools.messages.create( | |||
| @@ -504,6 +513,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): | |||
| "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, | |||
| } | |||
| sub_messages.append(sub_message_dict) | |||
| elif isinstance(message_content, DocumentPromptMessageContent): | |||
| if message_content.mime_type != "application/pdf": | |||
| raise ValueError( | |||
| f"Unsupported document type {message_content.mime_type}, " | |||
| "only support application/pdf" | |||
| ) | |||
| sub_message_dict = { | |||
| "type": "document", | |||
| "source": { | |||
| "type": message_content.encode_format, | |||
| "media_type": message_content.mime_type, | |||
| "data": message_content.data, | |||
| }, | |||
| } | |||
| sub_messages.append(sub_message_dict) | |||
| prompt_message_dicts.append({"role": "user", "content": sub_messages}) | |||
| elif isinstance(message, AssistantPromptMessage): | |||
| message = cast(AssistantPromptMessage, message) | |||
| @@ -7,6 +7,7 @@ features: | |||
| - multi-tool-call | |||
| - agent-thought | |||
| - stream-tool-call | |||
| - audio | |||
| model_properties: | |||
| mode: chat | |||
| context_size: 128000 | |||
| @@ -1,3 +1,4 @@ | |||
| from collections.abc import Sequence | |||
| from typing import cast | |||
| from core.model_runtime.entities import ( | |||
| @@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode | |||
| class PromptMessageUtil: | |||
| @staticmethod | |||
| def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: | |||
| def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]: | |||
| """ | |||
| Prompt messages to prompt for saving. | |||
| :param model_mode: model mode | |||
| @@ -118,11 +118,11 @@ class FileSegment(Segment): | |||
| @property | |||
| def log(self) -> str: | |||
| return str(self.value) | |||
| return "" | |||
| @property | |||
| def text(self) -> str: | |||
| return str(self.value) | |||
| return "" | |||
| class ArrayAnySegment(ArraySegment): | |||
| @@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment): | |||
| for item in self.value: | |||
| items.append(item.markdown) | |||
| return "\n".join(items) | |||
| @property | |||
| def log(self) -> str: | |||
| return "" | |||
| @property | |||
| def text(self) -> str: | |||
| return "" | |||
| @@ -39,7 +39,14 @@ class VisionConfig(BaseModel): | |||
| class PromptConfig(BaseModel): | |||
| jinja2_variables: Optional[list[VariableSelector]] = None | |||
| jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) | |||
| @field_validator("jinja2_variables", mode="before") | |||
| @classmethod | |||
| def convert_none_jinja2_variables(cls, v: Any): | |||
| if v is None: | |||
| return [] | |||
| return v | |||
| class LLMNodeChatModelMessage(ChatModelMessage): | |||
| @@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): | |||
| class LLMNodeData(BaseNodeData): | |||
| model: ModelConfig | |||
| prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate | |||
| prompt_config: Optional[PromptConfig] = None | |||
| prompt_config: PromptConfig = Field(default_factory=PromptConfig) | |||
| memory: Optional[MemoryConfig] = None | |||
| context: ContextConfig | |||
| vision: VisionConfig = Field(default_factory=VisionConfig) | |||
| @field_validator("prompt_config", mode="before") | |||
| @classmethod | |||
| def convert_none_prompt_config(cls, v: Any): | |||
| if v is None: | |||
| return PromptConfig() | |||
| return v | |||
| @@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError): | |||
| class NoPromptFoundError(LLMNodeError): | |||
| """Raised when no prompt is found in the LLM configuration.""" | |||
| class NotSupportedPromptTypeError(LLMNodeError): | |||
| """Raised when the prompt type is not supported.""" | |||
| class MemoryRolePrefixRequiredError(LLMNodeError): | |||
| """Raised when memory role prefix is required for completion model.""" | |||
| @@ -1,4 +1,5 @@ | |||
| import json | |||
| import logging | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| @@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.file import FileType, file_manager | |||
| from core.helper.code_executor import CodeExecutor, CodeLanguage | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities import ( | |||
| AudioPromptMessageContent, | |||
| ImagePromptMessageContent, | |||
| PromptMessage, | |||
| PromptMessageContentType, | |||
| TextPromptMessageContent, | |||
| VideoPromptMessageContent, | |||
| ) | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.message_entities import ( | |||
| AssistantPromptMessage, | |||
| PromptMessageRole, | |||
| SystemPromptMessage, | |||
| UserPromptMessage, | |||
| ) | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.variables import ( | |||
| @@ -32,8 +38,9 @@ from core.variables import ( | |||
| ObjectSegment, | |||
| StringSegment, | |||
| ) | |||
| from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.variable_entities import VariableSelector | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import InNodeEvent | |||
| from core.workflow.nodes.base import BaseNode | |||
| @@ -62,14 +69,18 @@ from .exc import ( | |||
| InvalidVariableTypeError, | |||
| LLMModeRequiredError, | |||
| LLMNodeError, | |||
| MemoryRolePrefixRequiredError, | |||
| ModelNotExistError, | |||
| NoPromptFoundError, | |||
| NotSupportedPromptTypeError, | |||
| VariableNotFoundError, | |||
| ) | |||
| if TYPE_CHECKING: | |||
| from core.file.models import File | |||
| logger = logging.getLogger(__name__) | |||
| class LLMNode(BaseNode[LLMNodeData]): | |||
| _node_data_cls = LLMNodeData | |||
| @@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| # fetch prompt messages | |||
| if self.node_data.memory: | |||
| query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) | |||
| if not query: | |||
| raise VariableNotFoundError("Query not found") | |||
| query = query.text | |||
| query = self.node_data.memory.query_prompt_template | |||
| else: | |||
| query = None | |||
| prompt_messages, stop = self._fetch_prompt_messages( | |||
| system_query=query, | |||
| inputs=inputs, | |||
| files=files, | |||
| user_query=query, | |||
| user_files=files, | |||
| context=context, | |||
| memory=memory, | |||
| model_config=model_config, | |||
| @@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| memory_config=self.node_data.memory, | |||
| vision_enabled=self.node_data.vision.enabled, | |||
| vision_detail=self.node_data.vision.configs.detail, | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| jinja2_variables=self.node_data.prompt_config.jinja2_variables, | |||
| ) | |||
| process_data = { | |||
| @@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| ) | |||
| ) | |||
| return | |||
| except Exception as e: | |||
| logger.exception(f"Node {self.node_id} failed to run") | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| inputs=node_inputs, | |||
| process_data=process_data, | |||
| ) | |||
| ) | |||
| return | |||
| outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} | |||
| @@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| self, | |||
| node_data_model: ModelConfig, | |||
| model_instance: ModelInstance, | |||
| prompt_messages: list[PromptMessage], | |||
| stop: Optional[list[str]] = None, | |||
| prompt_messages: Sequence[PromptMessage], | |||
| stop: Optional[Sequence[str]] = None, | |||
| ) -> Generator[NodeEvent, None, None]: | |||
| db.session.close() | |||
| @@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| def _fetch_prompt_messages( | |||
| self, | |||
| *, | |||
| system_query: str | None = None, | |||
| inputs: dict[str, str] | None = None, | |||
| files: Sequence["File"], | |||
| user_query: str | None = None, | |||
| user_files: Sequence["File"], | |||
| context: str | None = None, | |||
| memory: TokenBufferMemory | None = None, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| @@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| memory_config: MemoryConfig | None = None, | |||
| vision_enabled: bool = False, | |||
| vision_detail: ImagePromptMessageContent.DETAIL, | |||
| ) -> tuple[list[PromptMessage], Optional[list[str]]]: | |||
| inputs = inputs or {} | |||
| prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) | |||
| prompt_messages = prompt_transform.get_prompt( | |||
| prompt_template=prompt_template, | |||
| inputs=inputs, | |||
| query=system_query or "", | |||
| files=files, | |||
| context=context, | |||
| memory_config=memory_config, | |||
| memory=memory, | |||
| model_config=model_config, | |||
| ) | |||
| stop = model_config.stop | |||
| variable_pool: VariablePool, | |||
| jinja2_variables: Sequence[VariableSelector], | |||
| ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: | |||
| prompt_messages = [] | |||
| if isinstance(prompt_template, list): | |||
| # For chat model | |||
| prompt_messages.extend( | |||
| _handle_list_messages( | |||
| messages=prompt_template, | |||
| context=context, | |||
| jinja2_variables=jinja2_variables, | |||
| variable_pool=variable_pool, | |||
| vision_detail_config=vision_detail, | |||
| ) | |||
| ) | |||
| # Get memory messages for chat mode | |||
| memory_messages = _handle_memory_chat_mode( | |||
| memory=memory, | |||
| memory_config=memory_config, | |||
| model_config=model_config, | |||
| ) | |||
| # Extend prompt_messages with memory messages | |||
| prompt_messages.extend(memory_messages) | |||
| # Add current query to the prompt messages | |||
| if user_query: | |||
| message = LLMNodeChatModelMessage( | |||
| text=user_query, | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ) | |||
| prompt_messages.extend( | |||
| _handle_list_messages( | |||
| messages=[message], | |||
| context="", | |||
| jinja2_variables=[], | |||
| variable_pool=variable_pool, | |||
| vision_detail_config=vision_detail, | |||
| ) | |||
| ) | |||
| elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): | |||
| # For completion model | |||
| prompt_messages.extend( | |||
| _handle_completion_template( | |||
| template=prompt_template, | |||
| context=context, | |||
| jinja2_variables=jinja2_variables, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| ) | |||
| # Get memory text for completion model | |||
| memory_text = _handle_memory_completion_mode( | |||
| memory=memory, | |||
| memory_config=memory_config, | |||
| model_config=model_config, | |||
| ) | |||
| # Insert histories into the prompt | |||
| prompt_content = prompt_messages[0].content | |||
| if "#histories#" in prompt_content: | |||
| prompt_content = prompt_content.replace("#histories#", memory_text) | |||
| else: | |||
| prompt_content = memory_text + "\n" + prompt_content | |||
| prompt_messages[0].content = prompt_content | |||
| # Add current query to the prompt message | |||
| if user_query: | |||
| prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) | |||
| prompt_messages[0].content = prompt_content | |||
| else: | |||
| errmsg = f"Prompt type {type(prompt_template)} is not supported" | |||
| logger.warning(errmsg) | |||
| raise NotSupportedPromptTypeError(errmsg) | |||
| if vision_enabled and user_files: | |||
| file_prompts = [] | |||
| for file in user_files: | |||
| file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) | |||
| file_prompts.append(file_prompt) | |||
| if ( | |||
| len(prompt_messages) > 0 | |||
| and isinstance(prompt_messages[-1], UserPromptMessage) | |||
| and isinstance(prompt_messages[-1].content, list) | |||
| ): | |||
| prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts) | |||
| else: | |||
| prompt_messages.append(UserPromptMessage(content=file_prompts)) | |||
| # Filter prompt messages | |||
| filtered_prompt_messages = [] | |||
| for prompt_message in prompt_messages: | |||
| if prompt_message.is_empty(): | |||
| continue | |||
| if not isinstance(prompt_message.content, str): | |||
| if isinstance(prompt_message.content, list): | |||
| prompt_message_content = [] | |||
| for content_item in prompt_message.content or []: | |||
| # Skip image if vision is disabled | |||
| if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: | |||
| for content_item in prompt_message.content: | |||
| # Skip content if features are not defined | |||
| if not model_config.model_schema.features: | |||
| if content_item.type != PromptMessageContentType.TEXT: | |||
| continue | |||
| prompt_message_content.append(content_item) | |||
| continue | |||
| if isinstance(content_item, ImagePromptMessageContent): | |||
| # Override vision config if LLM node has vision config, | |||
| # cuz vision detail is related to the configuration from FileUpload feature. | |||
| content_item.detail = vision_detail | |||
| prompt_message_content.append(content_item) | |||
| elif isinstance( | |||
| content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent | |||
| # Skip content if corresponding feature is not supported | |||
| if ( | |||
| ( | |||
| content_item.type == PromptMessageContentType.IMAGE | |||
| and ModelFeature.VISION not in model_config.model_schema.features | |||
| ) | |||
| or ( | |||
| content_item.type == PromptMessageContentType.DOCUMENT | |||
| and ModelFeature.DOCUMENT not in model_config.model_schema.features | |||
| ) | |||
| or ( | |||
| content_item.type == PromptMessageContentType.VIDEO | |||
| and ModelFeature.VIDEO not in model_config.model_schema.features | |||
| ) | |||
| or ( | |||
| content_item.type == PromptMessageContentType.AUDIO | |||
| and ModelFeature.AUDIO not in model_config.model_schema.features | |||
| ) | |||
| ): | |||
| prompt_message_content.append(content_item) | |||
| if len(prompt_message_content) > 1: | |||
| prompt_message.content = prompt_message_content | |||
| elif ( | |||
| len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT | |||
| ): | |||
| continue | |||
| prompt_message_content.append(content_item) | |||
| if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: | |||
| prompt_message.content = prompt_message_content[0].data | |||
| else: | |||
| prompt_message.content = prompt_message_content | |||
| if prompt_message.is_empty(): | |||
| continue | |||
| filtered_prompt_messages.append(prompt_message) | |||
| if not filtered_prompt_messages: | |||
| if len(filtered_prompt_messages) == 0: | |||
| raise NoPromptFoundError( | |||
| "No prompt found in the LLM configuration. " | |||
| "Please ensure a prompt is properly configured before proceeding." | |||
| ) | |||
| stop = model_config.stop | |||
| return filtered_prompt_messages, stop | |||
| @classmethod | |||
| @@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| } | |||
| }, | |||
| } | |||
| def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): | |||
| match role: | |||
| case PromptMessageRole.USER: | |||
| return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) | |||
| case PromptMessageRole.ASSISTANT: | |||
| return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) | |||
| case PromptMessageRole.SYSTEM: | |||
| return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) | |||
| raise NotImplementedError(f"Role {role} is not supported") | |||
| def _render_jinja2_message( | |||
| *, | |||
| template: str, | |||
| jinjia2_variables: Sequence[VariableSelector], | |||
| variable_pool: VariablePool, | |||
| ): | |||
| if not template: | |||
| return "" | |||
| jinjia2_inputs = {} | |||
| for jinja2_variable in jinjia2_variables: | |||
| variable = variable_pool.get(jinja2_variable.value_selector) | |||
| jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" | |||
| code_execute_resp = CodeExecutor.execute_workflow_code_template( | |||
| language=CodeLanguage.JINJA2, | |||
| code=template, | |||
| inputs=jinjia2_inputs, | |||
| ) | |||
| result_text = code_execute_resp["result"] | |||
| return result_text | |||
| def _handle_list_messages( | |||
| *, | |||
| messages: Sequence[LLMNodeChatModelMessage], | |||
| context: Optional[str], | |||
| jinja2_variables: Sequence[VariableSelector], | |||
| variable_pool: VariablePool, | |||
| vision_detail_config: ImagePromptMessageContent.DETAIL, | |||
| ) -> Sequence[PromptMessage]: | |||
| prompt_messages = [] | |||
| for message in messages: | |||
| if message.edition_type == "jinja2": | |||
| result_text = _render_jinja2_message( | |||
| template=message.jinja2_text or "", | |||
| jinjia2_variables=jinja2_variables, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) | |||
| prompt_messages.append(prompt_message) | |||
| else: | |||
| # Get segment group from basic message | |||
| if context: | |||
| template = message.text.replace("{#context#}", context) | |||
| else: | |||
| template = message.text | |||
| segment_group = variable_pool.convert_template(template) | |||
| # Process segments for images | |||
| file_contents = [] | |||
| for segment in segment_group.value: | |||
| if isinstance(segment, ArrayFileSegment): | |||
| for file in segment.value: | |||
| if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: | |||
| file_content = file_manager.to_prompt_message_content( | |||
| file, image_detail_config=vision_detail_config | |||
| ) | |||
| file_contents.append(file_content) | |||
| if isinstance(segment, FileSegment): | |||
| file = segment.value | |||
| if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: | |||
| file_content = file_manager.to_prompt_message_content( | |||
| file, image_detail_config=vision_detail_config | |||
| ) | |||
| file_contents.append(file_content) | |||
| # Create message with text from all segments | |||
| plain_text = segment_group.text | |||
| if plain_text: | |||
| prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) | |||
| prompt_messages.append(prompt_message) | |||
| if file_contents: | |||
| # Create message with image contents | |||
| prompt_message = UserPromptMessage(content=file_contents) | |||
| prompt_messages.append(prompt_message) | |||
| return prompt_messages | |||
| def _calculate_rest_token( | |||
| *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity | |||
| ) -> int: | |||
| rest_tokens = 2000 | |||
| model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) | |||
| if model_context_tokens: | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=model_config.provider_model_bundle, model=model_config.model | |||
| ) | |||
| curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) | |||
| max_tokens = 0 | |||
| for parameter_rule in model_config.model_schema.parameter_rules: | |||
| if parameter_rule.name == "max_tokens" or ( | |||
| parameter_rule.use_template and parameter_rule.use_template == "max_tokens" | |||
| ): | |||
| max_tokens = ( | |||
| model_config.parameters.get(parameter_rule.name) | |||
| or model_config.parameters.get(str(parameter_rule.use_template)) | |||
| or 0 | |||
| ) | |||
| rest_tokens = model_context_tokens - max_tokens - curr_message_tokens | |||
| rest_tokens = max(rest_tokens, 0) | |||
| return rest_tokens | |||
| def _handle_memory_chat_mode( | |||
| *, | |||
| memory: TokenBufferMemory | None, | |||
| memory_config: MemoryConfig | None, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| ) -> Sequence[PromptMessage]: | |||
| memory_messages = [] | |||
| # Get messages from memory for chat model | |||
| if memory and memory_config: | |||
| rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) | |||
| memory_messages = memory.get_history_prompt_messages( | |||
| max_token_limit=rest_tokens, | |||
| message_limit=memory_config.window.size if memory_config.window.enabled else None, | |||
| ) | |||
| return memory_messages | |||
| def _handle_memory_completion_mode( | |||
| *, | |||
| memory: TokenBufferMemory | None, | |||
| memory_config: MemoryConfig | None, | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| ) -> str: | |||
| memory_text = "" | |||
| # Get history text from memory for completion model | |||
| if memory and memory_config: | |||
| rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) | |||
| if not memory_config.role_prefix: | |||
| raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") | |||
| memory_text = memory.get_history_prompt_text( | |||
| max_token_limit=rest_tokens, | |||
| message_limit=memory_config.window.size if memory_config.window.enabled else None, | |||
| human_prefix=memory_config.role_prefix.user, | |||
| ai_prefix=memory_config.role_prefix.assistant, | |||
| ) | |||
| return memory_text | |||
| def _handle_completion_template( | |||
| *, | |||
| template: LLMNodeCompletionModelPromptTemplate, | |||
| context: Optional[str], | |||
| jinja2_variables: Sequence[VariableSelector], | |||
| variable_pool: VariablePool, | |||
| ) -> Sequence[PromptMessage]: | |||
| """Handle completion template processing outside of LLMNode class. | |||
| Args: | |||
| template: The completion model prompt template | |||
| context: Optional context string | |||
| jinja2_variables: Variables for jinja2 template rendering | |||
| variable_pool: Variable pool for template conversion | |||
| Returns: | |||
| Sequence of prompt messages | |||
| """ | |||
| prompt_messages = [] | |||
| if template.edition_type == "jinja2": | |||
| result_text = _render_jinja2_message( | |||
| template=template.jinja2_text or "", | |||
| jinjia2_variables=jinja2_variables, | |||
| variable_pool=variable_pool, | |||
| ) | |||
| else: | |||
| if context: | |||
| template_text = template.text.replace("{#context#}", context) | |||
| else: | |||
| template_text = template.text | |||
| result_text = variable_pool.convert_template(template_text).text | |||
| prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) | |||
| prompt_messages.append(prompt_message) | |||
| return prompt_messages | |||
| @@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode): | |||
| ) | |||
| prompt_messages, stop = self._fetch_prompt_messages( | |||
| prompt_template=prompt_template, | |||
| system_query=query, | |||
| user_query=query, | |||
| memory=memory, | |||
| model_config=model_config, | |||
| files=files, | |||
| user_files=files, | |||
| vision_enabled=node_data.vision.enabled, | |||
| vision_detail=node_data.vision.configs.detail, | |||
| variable_pool=variable_pool, | |||
| jinja2_variables=[], | |||
| ) | |||
| # handle invoke result | |||
| @@ -2423,6 +2423,21 @@ files = [ | |||
| [package.extras] | |||
| test = ["pytest (>=6)"] | |||
| [[package]] | |||
| name = "faker" | |||
| version = "32.1.0" | |||
| description = "Faker is a Python package that generates fake data for you." | |||
| optional = false | |||
| python-versions = ">=3.8" | |||
| files = [ | |||
| {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"}, | |||
| {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"}, | |||
| ] | |||
| [package.dependencies] | |||
| python-dateutil = ">=2.4" | |||
| typing-extensions = "*" | |||
| [[package]] | |||
| name = "fal-client" | |||
| version = "0.5.6" | |||
| @@ -11041,4 +11056,4 @@ cffi = ["cffi (>=1.11)"] | |||
| [metadata] | |||
| lock-version = "2.0" | |||
| python-versions = ">=3.10,<3.13" | |||
| content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69" | |||
| content-hash = "d149b24ce7a203fa93eddbe8430d8ea7e5160a89c8d348b1b747c19899065639" | |||
| @@ -268,6 +268,7 @@ weaviate-client = "~3.21.0" | |||
| optional = true | |||
| [tool.poetry.group.dev.dependencies] | |||
| coverage = "~7.2.4" | |||
| faker = "~32.1.0" | |||
| pytest = "~8.3.2" | |||
| pytest-benchmark = "~4.0.0" | |||
| pytest-env = "~1.1.3" | |||
| @@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import ( | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel | |||
| from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock | |||
| @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) | |||
| @@ -4,29 +4,21 @@ import pytest | |||
| from core.model_runtime.entities.rerank_entities import RerankResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel | |||
| from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel | |||
| def test_validate_credentials(): | |||
| model = AzureAIStudioRerankModel() | |||
| model = AzureRerankModel() | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| model.validate_credentials( | |||
| model="azure-ai-studio-rerank-v1", | |||
| credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, | |||
| query="What is the capital of the United States?", | |||
| docs=[ | |||
| "Carson City is the capital city of the American state of Nevada. At the 2010 United States " | |||
| "Census, Carson City had a population of 55,274.", | |||
| "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " | |||
| "are a political division controlled by the United States. Its capital is Saipan.", | |||
| ], | |||
| score_threshold=0.8, | |||
| ) | |||
| def test_invoke_model(): | |||
| model = AzureAIStudioRerankModel() | |||
| model = AzureRerankModel() | |||
| result = model.invoke( | |||
| model="azure-ai-studio-rerank-v1", | |||
| @@ -1,125 +1,484 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Optional | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity | |||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle | |||
| from core.entities.provider_entities import CustomConfiguration, SystemConfiguration | |||
| from core.file import File, FileTransferMethod, FileType | |||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.message_entities import ( | |||
| AssistantPromptMessage, | |||
| ImagePromptMessageContent, | |||
| PromptMessage, | |||
| PromptMessageRole, | |||
| SystemPromptMessage, | |||
| TextPromptMessageContent, | |||
| UserPromptMessage, | |||
| ) | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel | |||
| from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity | |||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.prompt.entities.advanced_prompt_entities import MemoryConfig | |||
| from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | |||
| from core.workflow.nodes.answer import AnswerStreamGenerateRoute | |||
| from core.workflow.nodes.end import EndStreamParam | |||
| from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions | |||
| from core.workflow.nodes.llm.entities import ( | |||
| ContextConfig, | |||
| LLMNodeChatModelMessage, | |||
| LLMNodeData, | |||
| ModelConfig, | |||
| VisionConfig, | |||
| VisionConfigOptions, | |||
| ) | |||
| from core.workflow.nodes.llm.node import LLMNode | |||
| from models.enums import UserFrom | |||
| from models.provider import ProviderType | |||
| from models.workflow import WorkflowType | |||
| from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario | |||
| class TestLLMNode: | |||
| @pytest.fixture | |||
| def llm_node(self): | |||
| data = LLMNodeData( | |||
| title="Test LLM", | |||
| model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), | |||
| prompt_template=[], | |||
| memory=None, | |||
| context=ContextConfig(enabled=False), | |||
| vision=VisionConfig( | |||
| enabled=True, | |||
| configs=VisionConfigOptions( | |||
| variable_selector=["sys", "files"], | |||
| detail=ImagePromptMessageContent.DETAIL.HIGH, | |||
| ), | |||
| ), | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables={}, | |||
| user_inputs={}, | |||
| ) | |||
| node = LLMNode( | |||
| id="1", | |||
| config={ | |||
| "id": "1", | |||
| "data": data.model_dump(), | |||
| }, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config={}, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| call_depth=0, | |||
| class MockTokenBufferMemory: | |||
| def __init__(self, history_messages=None): | |||
| self.history_messages = history_messages or [] | |||
| def get_history_prompt_messages( | |||
| self, max_token_limit: int = 2000, message_limit: Optional[int] = None | |||
| ) -> Sequence[PromptMessage]: | |||
| if message_limit is not None: | |||
| return self.history_messages[-message_limit * 2 :] | |||
| return self.history_messages | |||
| @pytest.fixture | |||
| def llm_node(): | |||
| data = LLMNodeData( | |||
| title="Test LLM", | |||
| model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), | |||
| prompt_template=[], | |||
| memory=None, | |||
| context=ContextConfig(enabled=False), | |||
| vision=VisionConfig( | |||
| enabled=True, | |||
| configs=VisionConfigOptions( | |||
| variable_selector=["sys", "files"], | |||
| detail=ImagePromptMessageContent.DETAIL.HIGH, | |||
| ), | |||
| graph=Graph( | |||
| root_node_id="1", | |||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||
| answer_dependencies={}, | |||
| answer_generate_route={}, | |||
| ), | |||
| end_stream_param=EndStreamParam( | |||
| end_dependencies={}, | |||
| end_stream_variable_selector_mapping={}, | |||
| ), | |||
| ), | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables={}, | |||
| user_inputs={}, | |||
| ) | |||
| node = LLMNode( | |||
| id="1", | |||
| config={ | |||
| "id": "1", | |||
| "data": data.model_dump(), | |||
| }, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| graph_config={}, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| call_depth=0, | |||
| ), | |||
| graph=Graph( | |||
| root_node_id="1", | |||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||
| answer_dependencies={}, | |||
| answer_generate_route={}, | |||
| ), | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=0, | |||
| end_stream_param=EndStreamParam( | |||
| end_dependencies={}, | |||
| end_stream_variable_selector_mapping={}, | |||
| ), | |||
| ) | |||
| return node | |||
| ), | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=0, | |||
| ), | |||
| ) | |||
| return node | |||
| @pytest.fixture | |||
| def model_config(): | |||
| # Create actual provider and model type instances | |||
| model_provider_factory = ModelProviderFactory() | |||
| provider_instance = model_provider_factory.get_provider_instance("openai") | |||
| model_type_instance = provider_instance.get_model_instance(ModelType.LLM) | |||
| # Create a ProviderModelBundle | |||
| provider_model_bundle = ProviderModelBundle( | |||
| configuration=ProviderConfiguration( | |||
| tenant_id="1", | |||
| provider=provider_instance.get_provider_schema(), | |||
| preferred_provider_type=ProviderType.CUSTOM, | |||
| using_provider_type=ProviderType.CUSTOM, | |||
| system_configuration=SystemConfiguration(enabled=False), | |||
| custom_configuration=CustomConfiguration(provider=None), | |||
| model_settings=[], | |||
| ), | |||
| provider_instance=provider_instance, | |||
| model_type_instance=model_type_instance, | |||
| ) | |||
| def test_fetch_files_with_file_segment(self, llm_node): | |||
| file = File( | |||
| # Create and return a ModelConfigWithCredentialsEntity | |||
| return ModelConfigWithCredentialsEntity( | |||
| provider="openai", | |||
| model="gpt-3.5-turbo", | |||
| model_schema=AIModelEntity( | |||
| model="gpt-3.5-turbo", | |||
| label=I18nObject(en_US="GPT-3.5 Turbo"), | |||
| model_type=ModelType.LLM, | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_properties={}, | |||
| ), | |||
| mode="chat", | |||
| credentials={}, | |||
| parameters={}, | |||
| provider_model_bundle=provider_model_bundle, | |||
| ) | |||
| def test_fetch_files_with_file_segment(llm_node): | |||
| file = File( | |||
| id="1", | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1", | |||
| ) | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [file] | |||
| def test_fetch_files_with_array_file_segment(llm_node): | |||
| files = [ | |||
| File( | |||
| id="1", | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test.jpg", | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1", | |||
| ), | |||
| File( | |||
| id="2", | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test2.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="2", | |||
| ), | |||
| ] | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == files | |||
| def test_fetch_files_with_none_segment(llm_node): | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [] | |||
| def test_fetch_files_with_array_any_segment(llm_node): | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [] | |||
| def test_fetch_files_with_non_existent_variable(llm_node): | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [] | |||
| def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): | |||
| prompt_template = [] | |||
| llm_node.node_data.prompt_template = prompt_template | |||
| fake_vision_detail = faker.random_element( | |||
| [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] | |||
| ) | |||
| fake_remote_url = faker.url() | |||
| files = [ | |||
| File( | |||
| id="1", | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url=fake_remote_url, | |||
| ) | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [file] | |||
| def test_fetch_files_with_array_file_segment(self, llm_node): | |||
| files = [ | |||
| File( | |||
| id="1", | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1", | |||
| ), | |||
| File( | |||
| id="2", | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test2.jpg", | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="2", | |||
| ), | |||
| ] | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) | |||
| ] | |||
| fake_query = faker.sentence() | |||
| prompt_messages, _ = llm_node._fetch_prompt_messages( | |||
| user_query=fake_query, | |||
| user_files=files, | |||
| context=None, | |||
| memory=None, | |||
| model_config=model_config, | |||
| prompt_template=prompt_template, | |||
| memory_config=None, | |||
| vision_enabled=False, | |||
| vision_detail=fake_vision_detail, | |||
| variable_pool=llm_node.graph_runtime_state.variable_pool, | |||
| jinja2_variables=[], | |||
| ) | |||
| assert prompt_messages == [UserPromptMessage(content=fake_query)] | |||
| def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||
| # Setup dify config | |||
| dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" | |||
| dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" | |||
| # Generate fake values for prompt template | |||
| fake_assistant_prompt = faker.sentence() | |||
| fake_query = faker.sentence() | |||
| fake_context = faker.sentence() | |||
| fake_window_size = faker.random_int(min=1, max=3) | |||
| fake_vision_detail = faker.random_element( | |||
| [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] | |||
| ) | |||
| fake_remote_url = faker.url() | |||
| # Setup mock memory with history messages | |||
| mock_history = [ | |||
| UserPromptMessage(content=faker.sentence()), | |||
| AssistantPromptMessage(content=faker.sentence()), | |||
| UserPromptMessage(content=faker.sentence()), | |||
| AssistantPromptMessage(content=faker.sentence()), | |||
| UserPromptMessage(content=faker.sentence()), | |||
| AssistantPromptMessage(content=faker.sentence()), | |||
| ] | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == files | |||
| # Setup memory configuration | |||
| memory_config = MemoryConfig( | |||
| role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), | |||
| window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size), | |||
| query_prompt_template=None, | |||
| ) | |||
| def test_fetch_files_with_none_segment(self, llm_node): | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) | |||
| memory = MockTokenBufferMemory(history_messages=mock_history) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [] | |||
| # Test scenarios covering different file input combinations | |||
| test_scenarios = [ | |||
| LLMNodeTestScenario( | |||
| description="No files", | |||
| user_query=fake_query, | |||
| user_files=[], | |||
| features=[], | |||
| vision_enabled=False, | |||
| vision_detail=None, | |||
| window_size=fake_window_size, | |||
| prompt_template=[ | |||
| LLMNodeChatModelMessage( | |||
| text=fake_context, | |||
| role=PromptMessageRole.SYSTEM, | |||
| edition_type="basic", | |||
| ), | |||
| LLMNodeChatModelMessage( | |||
| text="{#context#}", | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ), | |||
| LLMNodeChatModelMessage( | |||
| text=fake_assistant_prompt, | |||
| role=PromptMessageRole.ASSISTANT, | |||
| edition_type="basic", | |||
| ), | |||
| ], | |||
| expected_messages=[ | |||
| SystemPromptMessage(content=fake_context), | |||
| UserPromptMessage(content=fake_context), | |||
| AssistantPromptMessage(content=fake_assistant_prompt), | |||
| ] | |||
| + mock_history[fake_window_size * -2 :] | |||
| + [ | |||
| UserPromptMessage(content=fake_query), | |||
| ], | |||
| ), | |||
| LLMNodeTestScenario( | |||
| description="User files", | |||
| user_query=fake_query, | |||
| user_files=[ | |||
| File( | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url=fake_remote_url, | |||
| ) | |||
| ], | |||
| vision_enabled=True, | |||
| vision_detail=fake_vision_detail, | |||
| features=[ModelFeature.VISION], | |||
| window_size=fake_window_size, | |||
| prompt_template=[ | |||
| LLMNodeChatModelMessage( | |||
| text=fake_context, | |||
| role=PromptMessageRole.SYSTEM, | |||
| edition_type="basic", | |||
| ), | |||
| LLMNodeChatModelMessage( | |||
| text="{#context#}", | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ), | |||
| LLMNodeChatModelMessage( | |||
| text=fake_assistant_prompt, | |||
| role=PromptMessageRole.ASSISTANT, | |||
| edition_type="basic", | |||
| ), | |||
| ], | |||
| expected_messages=[ | |||
| SystemPromptMessage(content=fake_context), | |||
| UserPromptMessage(content=fake_context), | |||
| AssistantPromptMessage(content=fake_assistant_prompt), | |||
| ] | |||
| + mock_history[fake_window_size * -2 :] | |||
| + [ | |||
| UserPromptMessage( | |||
| content=[ | |||
| TextPromptMessageContent(data=fake_query), | |||
| ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), | |||
| ] | |||
| ), | |||
| ], | |||
| ), | |||
| LLMNodeTestScenario( | |||
| description="Prompt template with variable selector of File", | |||
| user_query=fake_query, | |||
| user_files=[], | |||
| vision_enabled=False, | |||
| vision_detail=fake_vision_detail, | |||
| features=[ModelFeature.VISION], | |||
| window_size=fake_window_size, | |||
| prompt_template=[ | |||
| LLMNodeChatModelMessage( | |||
| text="{{#input.image#}}", | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ), | |||
| ], | |||
| expected_messages=[ | |||
| UserPromptMessage( | |||
| content=[ | |||
| ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), | |||
| ] | |||
| ), | |||
| ] | |||
| + mock_history[fake_window_size * -2 :] | |||
| + [UserPromptMessage(content=fake_query)], | |||
| file_variables={ | |||
| "input.image": File( | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url=fake_remote_url, | |||
| ) | |||
| }, | |||
| ), | |||
| LLMNodeTestScenario( | |||
| description="Prompt template with variable selector of File without vision feature", | |||
| user_query=fake_query, | |||
| user_files=[], | |||
| vision_enabled=True, | |||
| vision_detail=fake_vision_detail, | |||
| features=[], | |||
| window_size=fake_window_size, | |||
| prompt_template=[ | |||
| LLMNodeChatModelMessage( | |||
| text="{{#input.image#}}", | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ), | |||
| ], | |||
| expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], | |||
| file_variables={ | |||
| "input.image": File( | |||
| tenant_id="test", | |||
| type=FileType.IMAGE, | |||
| filename="test1.jpg", | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url=fake_remote_url, | |||
| ) | |||
| }, | |||
| ), | |||
| LLMNodeTestScenario( | |||
| description="Prompt template with variable selector of File with video file and vision feature", | |||
| user_query=fake_query, | |||
| user_files=[], | |||
| vision_enabled=True, | |||
| vision_detail=fake_vision_detail, | |||
| features=[ModelFeature.VISION], | |||
| window_size=fake_window_size, | |||
| prompt_template=[ | |||
| LLMNodeChatModelMessage( | |||
| text="{{#input.image#}}", | |||
| role=PromptMessageRole.USER, | |||
| edition_type="basic", | |||
| ), | |||
| ], | |||
| expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], | |||
| file_variables={ | |||
| "input.image": File( | |||
| tenant_id="test", | |||
| type=FileType.VIDEO, | |||
| filename="test1.mp4", | |||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||
| remote_url=fake_remote_url, | |||
| extension="mp4", | |||
| ) | |||
| }, | |||
| ), | |||
| ] | |||
| def test_fetch_files_with_array_any_segment(self, llm_node): | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) | |||
| for scenario in test_scenarios: | |||
| model_config.model_schema.features = scenario.features | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [] | |||
| for k, v in scenario.file_variables.items(): | |||
| selector = k.split(".") | |||
| llm_node.graph_runtime_state.variable_pool.add(selector, v) | |||
| # Call the method under test | |||
| prompt_messages, _ = llm_node._fetch_prompt_messages( | |||
| user_query=scenario.user_query, | |||
| user_files=scenario.user_files, | |||
| context=fake_context, | |||
| memory=memory, | |||
| model_config=model_config, | |||
| prompt_template=scenario.prompt_template, | |||
| memory_config=memory_config, | |||
| vision_enabled=scenario.vision_enabled, | |||
| vision_detail=scenario.vision_detail, | |||
| variable_pool=llm_node.graph_runtime_state.variable_pool, | |||
| jinja2_variables=[], | |||
| ) | |||
| def test_fetch_files_with_non_existent_variable(self, llm_node): | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| assert result == [] | |||
| # Verify the result | |||
| assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" | |||
| assert ( | |||
| prompt_messages == scenario.expected_messages | |||
| ), f"Message content mismatch in scenario: {scenario.description}" | |||
| @@ -0,0 +1,25 @@ | |||
| from collections.abc import Mapping, Sequence | |||
| from pydantic import BaseModel, Field | |||
| from core.file import File | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelFeature | |||
| from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage | |||
| class LLMNodeTestScenario(BaseModel): | |||
| """Test scenario for LLM node testing.""" | |||
| description: str = Field(..., description="Description of the test scenario") | |||
| user_query: str = Field(..., description="User query input") | |||
| user_files: Sequence[File] = Field(default_factory=list, description="List of user files") | |||
| vision_enabled: bool = Field(default=False, description="Whether vision is enabled") | |||
| vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") | |||
| features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") | |||
| window_size: int = Field(..., description="Window size for memory") | |||
| prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") | |||
| file_variables: Mapping[str, File | Sequence[File]] = Field( | |||
| default_factory=dict, description="List of file variables" | |||
| ) | |||
| expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") | |||
| @@ -160,6 +160,7 @@ const CodeEditor: FC<Props> = ({ | |||
| hideSearch | |||
| vars={availableVars} | |||
| onChange={handleSelectVar} | |||
| isSupportFileVar={false} | |||
| /> | |||
| </div> | |||
| )} | |||
| @@ -18,6 +18,7 @@ type Props = { | |||
| isSupportConstantValue?: boolean | |||
| onlyLeafNodeVar?: boolean | |||
| filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean | |||
| isSupportFileVar?: boolean | |||
| } | |||
| const VarList: FC<Props> = ({ | |||
| @@ -29,6 +30,7 @@ const VarList: FC<Props> = ({ | |||
| isSupportConstantValue, | |||
| onlyLeafNodeVar, | |||
| filterVar, | |||
| isSupportFileVar = true, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| @@ -94,6 +96,7 @@ const VarList: FC<Props> = ({ | |||
| defaultVarKindType={item.variable_type} | |||
| onlyLeafNodeVar={onlyLeafNodeVar} | |||
| filterVar={filterVar} | |||
| isSupportFileVar={isSupportFileVar} | |||
| /> | |||
| {!readonly && ( | |||
| <RemoveButton | |||
| @@ -59,6 +59,7 @@ type Props = { | |||
| isInTable?: boolean | |||
| onRemove?: () => void | |||
| typePlaceHolder?: string | |||
| isSupportFileVar?: boolean | |||
| } | |||
| const VarReferencePicker: FC<Props> = ({ | |||
| @@ -81,6 +82,7 @@ const VarReferencePicker: FC<Props> = ({ | |||
| isInTable, | |||
| onRemove, | |||
| typePlaceHolder, | |||
| isSupportFileVar = true, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const store = useStoreApi() | |||
| @@ -382,6 +384,7 @@ const VarReferencePicker: FC<Props> = ({ | |||
| vars={outputVars} | |||
| onChange={handleVarReferenceChange} | |||
| itemWidth={isAddBtnTrigger ? 260 : triggerWidth} | |||
| isSupportFileVar={isSupportFileVar} | |||
| /> | |||
| )} | |||
| </PortalToFollowElemContent> | |||
| @@ -8,11 +8,13 @@ type Props = { | |||
| vars: NodeOutPutVar[] | |||
| onChange: (value: ValueSelector, varDetail: Var) => void | |||
| itemWidth?: number | |||
| isSupportFileVar?: boolean | |||
| } | |||
| const VarReferencePopup: FC<Props> = ({ | |||
| vars, | |||
| onChange, | |||
| itemWidth, | |||
| isSupportFileVar = true, | |||
| }) => { | |||
| // max-h-[300px] overflow-y-auto todo: use portal to handle long list | |||
| return ( | |||
| @@ -24,7 +26,7 @@ const VarReferencePopup: FC<Props> = ({ | |||
| vars={vars} | |||
| onChange={onChange} | |||
| itemWidth={itemWidth} | |||
| isSupportFileVar | |||
| isSupportFileVar={isSupportFileVar} | |||
| /> | |||
| </div > | |||
| ) | |||
| @@ -89,6 +89,7 @@ const Panel: FC<NodePanelProps<CodeNodeType>> = ({ | |||
| list={inputs.variables} | |||
| onChange={handleVarListChange} | |||
| filterVar={filterVar} | |||
| isSupportFileVar={false} | |||
| /> | |||
| </Field> | |||
| <Split /> | |||
| @@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({ | |||
| onEditionTypeChange={onEditionTypeChange} | |||
| varList={varList} | |||
| handleAddVariable={handleAddVariable} | |||
| isSupportFileVar | |||
| /> | |||
| ) | |||
| } | |||
| @@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({ | |||
| handleStop, | |||
| varInputs, | |||
| runResult, | |||
| filterJinjia2InputVar, | |||
| } = useConfig(id, data) | |||
| const model = inputs.model | |||
| @@ -194,7 +195,8 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({ | |||
| list={inputs.prompt_config?.jinja2_variables || []} | |||
| onChange={handleVarListChange} | |||
| onVarNameChange={handleVarNameChange} | |||
| filterVar={filterVar} | |||
| filterVar={filterJinjia2InputVar} | |||
| isSupportFileVar={false} | |||
| /> | |||
| </Field> | |||
| )} | |||
| @@ -233,6 +235,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({ | |||
| hasSetBlockStatus={hasSetBlockStatus} | |||
| nodesOutputVars={availableVars} | |||
| availableNodes={availableNodesWithParent} | |||
| isSupportFileVar | |||
| /> | |||
| {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( | |||
| @@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => { | |||
| }, [inputs, setInputs]) | |||
| const filterInputVar = useCallback((varPayload: Var) => { | |||
| return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type) | |||
| }, []) | |||
| const filterJinjia2InputVar = useCallback((varPayload: Var) => { | |||
| return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) | |||
| }, []) | |||
| const filterMemoryPromptVar = useCallback((varPayload: Var) => { | |||
| return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) | |||
| return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type) | |||
| }, []) | |||
| const { | |||
| @@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { | |||
| handleRun, | |||
| handleStop, | |||
| runResult, | |||
| filterJinjia2InputVar, | |||
| } | |||
| } | |||
| @@ -64,6 +64,7 @@ const Panel: FC<NodePanelProps<TemplateTransformNodeType>> = ({ | |||
| onChange={handleVarListChange} | |||
| onVarNameChange={handleVarNameChange} | |||
| filterVar={filterVar} | |||
| isSupportFileVar={false} | |||
| /> | |||
| </Field> | |||
| <Split /> | |||