Co-authored-by: Joel <iamjoel007@gmail.com>
tags/0.12.0
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
+        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )

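A minimal illustration (not part of this change set) of what `frozen=True` buys in pydantic-settings, using a stand-in settings class rather than Dify's own:

from pydantic_settings import BaseSettings, SettingsConfigDict

class ExampleConfig(BaseSettings):
    # mirrors the options above; ExampleConfig is a placeholder, not Dify code
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        frozen=True,
        extra="ignore",
    )
    DEBUG: bool = False

config = ExampleConfig()
try:
    config.DEBUG = True  # frozen instances reject attribute assignment
except Exception as exc:
    print(f"rejected: {exc}")
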
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
         )
         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
         # model config
         completion_params = model_config.parameters
         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
         return ModelConfigWithCredentialsEntity(

         ).total_seconds()
         db.session.commit()
-        db.session.refresh(workflow_run)
         db.session.close()
+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(

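As an aside, a self-contained sketch (plain SQLAlchemy, not Dify code) of why `expire_on_commit=False` is used here: the refreshed object stays readable after the session block ends, which matters when `workflow_run` is handed to the trace task afterwards.

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Run(Base):  # stand-in for WorkflowRun, illustration only
    __tablename__ = "runs"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(default="succeeded")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    run = Run()
    session.add(run)
    session.commit()
    session.refresh(run)  # reload the row while still attached

print(run.status)  # attributes are not expired, so no lazy load on a closed session
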
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
+        case _:
+            raise ValueError(f"Invalid file attribute: {attr}")
 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
-    """
     match f.type:
         case FileType.IMAGE:
+            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
             if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                 data = _to_url(f)
             else:
             return ImagePromptMessageContent(data=data, detail=image_detail_config)
         case FileType.AUDIO:
-            encoded_string = _file_to_encoded_string(f)
+            encoded_string = _get_encoded_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
                 data = _to_url(f)
             else:
                 data = _to_base64_data_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
+        case FileType.DOCUMENT:
+            data = _get_encoded_string(f)
+            if f.mime_type is None:
+                raise ValueError("Missing file mime_type")
+            return DocumentPromptMessageContent(
+                encode_format="base64",
+                mime_type=f.mime_type,
+                data=data,
+            )
         case _:
-            raise ValueError("file type f.type is not supported")
+            raise ValueError(f"file type {f.type} is not supported")
 def download(f: File, /):
         case FileTransferMethod.REMOTE_URL:
             response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
-            content = response.content
-            encoded_string = base64.b64encode(content).decode("utf-8")
-            return encoded_string
+            data = response.content
         case FileTransferMethod.LOCAL_FILE:
             upload_file = file_repository.get_upload_file(session=db.session(), file=f)
             data = _download_file_content(upload_file.key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
         case FileTransferMethod.TOOL_FILE:
             tool_file = file_repository.get_tool_file(session=db.session(), file=f)
             data = _download_file_content(tool_file.file_key)
-            encoded_string = base64.b64encode(data).decode("utf-8")
-            return encoded_string
+        case _:
+            raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
+    encoded_string = base64.b64encode(data).decode("utf-8")
+    return encoded_string
 def _to_base64_data_string(f: File, /):
     return f"data:{f.mime_type};base64,{encoded_string}"
-def _file_to_encoded_string(f: File, /):
-    match f.type:
-        case FileType.IMAGE:
-            return _to_base64_data_string(f)
-        case FileType.VIDEO:
-            return _to_base64_data_string(f)
-        case FileType.AUDIO:
-            return _get_encoded_string(f)
-        case _:
-            raise ValueError(f"file type {f.type} is not supported")
 def _to_url(f: File, /):
     if f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:

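A hedged usage sketch of the new document branch; the exact `File` constructor fields are assumptions based on the attributes referenced above (`type`, `mime_type`, `extension`, `transfer_method`), not a verified signature:

from core.file import File, FileTransferMethod, FileType, file_manager

pdf = File(  # field names assumed for illustration
    tenant_id="tenant-id",
    type=FileType.DOCUMENT,
    transfer_method=FileTransferMethod.REMOTE_URL,
    remote_url="https://example.com/report.pdf",
    mime_type="application/pdf",
    extension=".pdf",
)
content = file_manager.to_prompt_message_content(pdf)
# expected: DocumentPromptMessageContent(encode_format="base64",
#           mime_type="application/pdf", data="<base64 of the downloaded bytes>")
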
+from collections.abc import Sequence
 from typing import Optional
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit

     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,

 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import Optional
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:

 from .message_entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
     "LLMResultChunk",
     "LLMResultChunkDelta",
     "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]

 from abc import ABC
+from collections.abc import Sequence
 from enum import Enum
-from typing import Optional
+from typing import Literal, Optional
 from pydantic import BaseModel, Field, field_validator
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"
 class PromptMessageContent(BaseModel):
     detail: DETAIL = DETAIL.LOW
+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """
     role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
     name: Optional[str] = None
     def is_empty(self) -> bool:

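For illustration, how the new content type behaves under pydantic validation; the base64 payload is truncated dummy data:

from pydantic import ValidationError
from core.model_runtime.entities import DocumentPromptMessageContent, PromptMessageContentType

doc = DocumentPromptMessageContent(
    encode_format="base64",
    mime_type="application/pdf",
    data="JVBERi0xLjQK...",  # truncated, illustrative only
)
assert doc.type == PromptMessageContentType.DOCUMENT

try:
    DocumentPromptMessageContent(encode_format="hex", mime_type="application/pdf", data="00ff")
except ValidationError:
    print("rejected: encode_format is Literal['base64']")
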
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
 class DefaultParameterName(str, Enum):

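A minimal sketch of the gate these new feature flags enable (used later in the LLM node): non-text content is only forwarded when the model schema declares the matching capability.

from core.model_runtime.entities.model_entities import ModelFeature

declared = [ModelFeature.VISION, ModelFeature.DOCUMENT]
print(ModelFeature.DOCUMENT in declared)  # True  -> document parts are kept
print(ModelFeature.AUDIO in declared)     # False -> audio parts are filtered out
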
 import re
 import time
 from abc import abstractmethod
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import Optional, Union
 from pydantic import ConfigDict
         prompt_messages: list[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
         )
         model_parameters.pop("response_format")
-        stop = stop or []
+        stop = list(stop) if stop is not None else []
         stop.extend(["\n```", "```\n"])
         block_prompts = block_prompts.replace("{{block}}", code_block)
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,

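The copy in `stop = list(stop) if stop is not None else []` matters because `stop` is now a read-only `Sequence` and may arrive as a tuple; a small standalone illustration:

from collections.abc import Sequence

def with_code_block_stops(stop: Sequence[str] | None) -> list[str]:
    stops = list(stop) if stop is not None else []  # copy; never mutate the caller's value
    stops.extend(["\n```", "```\n"])
    return stops

print(with_code_block_stops(("Human:",)))  # ['Human:', '\n```', '```\n']
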
   - vision
   - tool-call
   - stream-tool-call
+  - document
 model_properties:
   mode: chat
   context_size: 200000

   - vision
   - tool-call
   - stream-tool-call
+  - document
 model_properties:
   mode: chat
   context_size: 200000

 import base64
 import io
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
 import anthropic
 from PIL import Image
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
         self,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         # Add the new header for claude-3-5-sonnet-20240620 model
         extra_headers = {}
         if model == "claude-3-5-sonnet-20240620":
-            if model_parameters.get("max_tokens") > 4096:
+            if model_parameters.get("max_tokens", 0) > 4096:
                 extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
+        if any(
+            isinstance(content, DocumentPromptMessageContent)
+            for prompt_message in prompt_messages
+            if isinstance(prompt_message.content, list)
+            for content in prompt_message.content
+        ):
+            extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
         if tools:
             extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
             response = client.beta.tools.messages.create(
                         "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
                     }
                     sub_messages.append(sub_message_dict)
+                elif isinstance(message_content, DocumentPromptMessageContent):
+                    if message_content.mime_type != "application/pdf":
+                        raise ValueError(
+                            f"Unsupported document type {message_content.mime_type}, "
+                            "only support application/pdf"
+                        )
+                    sub_message_dict = {
+                        "type": "document",
+                        "source": {
+                            "type": message_content.encode_format,
+                            "media_type": message_content.mime_type,
+                            "data": message_content.data,
+                        },
+                    }
+                    sub_messages.append(sub_message_dict)
             prompt_message_dicts.append({"role": "user", "content": sub_messages})
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)

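For reference, an illustrative JSON payload of the user turn produced by the document branch above (the base64 data is truncated dummy text); the request also carries the `anthropic-beta: pdfs-2024-09-25` header set earlier:

import json

user_turn = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Summarize this report."},
        {
            "type": "document",
            "source": {
                "type": "base64",               # encode_format
                "media_type": "application/pdf",
                "data": "JVBERi0xLjQK...",      # truncated, illustrative only
            },
        },
    ],
}
print(json.dumps(user_turn, indent=2))
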
   - multi-tool-call
   - agent-thought
   - stream-tool-call
+  - audio
 model_properties:
   mode: chat
   context_size: 128000

+from collections.abc import Sequence
 from typing import cast
 from core.model_runtime.entities import (
 class PromptMessageUtil:
     @staticmethod
-    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
+    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
         """
         Prompt messages to prompt for saving.
         :param model_mode: model mode

     @property
     def log(self) -> str:
-        return str(self.value)
+        return ""
     @property
     def text(self) -> str:
-        return str(self.value)
+        return ""
 class ArrayAnySegment(ArraySegment):
         for item in self.value:
             items.append(item.markdown)
         return "\n".join(items)
+    @property
+    def log(self) -> str:
+        return ""
+    @property
+    def text(self) -> str:
+        return ""

 class PromptConfig(BaseModel):
-    jinja2_variables: Optional[list[VariableSelector]] = None
+    jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
+    @field_validator("jinja2_variables", mode="before")
+    @classmethod
+    def convert_none_jinja2_variables(cls, v: Any):
+        if v is None:
+            return []
+        return v
 class LLMNodeChatModelMessage(ChatModelMessage):
 class LLMNodeData(BaseNodeData):
     model: ModelConfig
     prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    prompt_config: Optional[PromptConfig] = None
+    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
+    @field_validator("prompt_config", mode="before")
+    @classmethod
+    def convert_none_prompt_config(cls, v: Any):
+        if v is None:
+            return PromptConfig()
+        return v

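Illustrative behaviour of the `mode="before"` validators (the module path is assumed to be the LLM node's entities module): a stored `None` is normalized instead of failing validation.

from core.workflow.nodes.llm.entities import PromptConfig  # path assumed

assert list(PromptConfig.model_validate({"jinja2_variables": None}).jinja2_variables) == []
assert list(PromptConfig().jinja2_variables) == []
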
 class NoPromptFoundError(LLMNodeError):
     """Raised when no prompt is found in the LLM configuration."""
+class NotSupportedPromptTypeError(LLMNodeError):
+    """Raised when the prompt type is not supported."""
+class MemoryRolePrefixRequiredError(LLMNodeError):
+    """Raised when memory role prefix is required for completion model."""

 import json
+import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional, cast
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.file import FileType, file_manager
+from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.variables import (
     ObjectSegment,
     StringSegment,
 )
-from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.variable_entities import VariableSelector
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base import BaseNode
     InvalidVariableTypeError,
     LLMModeRequiredError,
     LLMNodeError,
+    MemoryRolePrefixRequiredError,
     ModelNotExistError,
     NoPromptFoundError,
+    NotSupportedPromptTypeError,
     VariableNotFoundError,
 )
 if TYPE_CHECKING:
     from core.file.models import File
+logger = logging.getLogger(__name__)
 class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData

         # fetch prompt messages
         if self.node_data.memory:
-            query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
-            if not query:
-                raise VariableNotFoundError("Query not found")
-            query = query.text
+            query = self.node_data.memory.query_prompt_template
         else:
             query = None
         prompt_messages, stop = self._fetch_prompt_messages(
-            system_query=query,
-            inputs=inputs,
-            files=files,
+            user_query=query,
+            user_files=files,
             context=context,
             memory=memory,
             model_config=model_config,
             memory_config=self.node_data.memory,
             vision_enabled=self.node_data.vision.enabled,
             vision_detail=self.node_data.vision.configs.detail,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            jinja2_variables=self.node_data.prompt_config.jinja2_variables,
         )
         process_data = {
                 )
             )
             return
+        except Exception as e:
+            logger.exception(f"Node {self.node_id} failed to run")
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data,
+                )
+            )
+            return
         outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}

         self,
         node_data_model: ModelConfig,
         model_instance: ModelInstance,
-        prompt_messages: list[PromptMessage],
-        stop: Optional[list[str]] = None,
+        prompt_messages: Sequence[PromptMessage],
+        stop: Optional[Sequence[str]] = None,
     ) -> Generator[NodeEvent, None, None]:
         db.session.close()
     def _fetch_prompt_messages(
         self,
         *,
-        system_query: str | None = None,
-        inputs: dict[str, str] | None = None,
-        files: Sequence["File"],
+        user_query: str | None = None,
+        user_files: Sequence["File"],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
         memory_config: MemoryConfig | None = None,
         vision_enabled: bool = False,
         vision_detail: ImagePromptMessageContent.DETAIL,
-    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
-        inputs = inputs or {}
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs=inputs,
-            query=system_query or "",
-            files=files,
-            context=context,
-            memory_config=memory_config,
-            memory=memory,
-            model_config=model_config,
-        )
-        stop = model_config.stop
+        variable_pool: VariablePool,
+        jinja2_variables: Sequence[VariableSelector],
+    ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
+        prompt_messages = []
+        if isinstance(prompt_template, list):
+            # For chat model
+            prompt_messages.extend(
+                _handle_list_messages(
+                    messages=prompt_template,
+                    context=context,
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                    vision_detail_config=vision_detail,
+                )
+            )
+            # Get memory messages for chat mode
+            memory_messages = _handle_memory_chat_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Extend prompt_messages with memory messages
+            prompt_messages.extend(memory_messages)
+            # Add current query to the prompt messages
+            if user_query:
+                message = LLMNodeChatModelMessage(
+                    text=user_query,
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                )
+                prompt_messages.extend(
+                    _handle_list_messages(
+                        messages=[message],
+                        context="",
+                        jinja2_variables=[],
+                        variable_pool=variable_pool,
+                        vision_detail_config=vision_detail,
+                    )
+                )
+        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+            # For completion model
+            prompt_messages.extend(
+                _handle_completion_template(
+                    template=prompt_template,
+                    context=context,
+                    jinja2_variables=jinja2_variables,
+                    variable_pool=variable_pool,
+                )
+            )
+            # Get memory text for completion model
+            memory_text = _handle_memory_completion_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_config=model_config,
+            )
+            # Insert histories into the prompt
+            prompt_content = prompt_messages[0].content
+            if "#histories#" in prompt_content:
+                prompt_content = prompt_content.replace("#histories#", memory_text)
+            else:
+                prompt_content = memory_text + "\n" + prompt_content
+            prompt_messages[0].content = prompt_content
+            # Add current query to the prompt message
+            if user_query:
+                prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
+                prompt_messages[0].content = prompt_content
+        else:
+            errmsg = f"Prompt type {type(prompt_template)} is not supported"
+            logger.warning(errmsg)
+            raise NotSupportedPromptTypeError(errmsg)
+        if vision_enabled and user_files:
+            file_prompts = []
+            for file in user_files:
+                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                file_prompts.append(file_prompt)
+            if (
+                len(prompt_messages) > 0
+                and isinstance(prompt_messages[-1], UserPromptMessage)
+                and isinstance(prompt_messages[-1].content, list)
+            ):
+                prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
+            else:
+                prompt_messages.append(UserPromptMessage(content=file_prompts))
+        # Filter prompt messages
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
-            if prompt_message.is_empty():
-                continue
-            if not isinstance(prompt_message.content, str):
+            if isinstance(prompt_message.content, list):
                 prompt_message_content = []
-                for content_item in prompt_message.content or []:
-                    # Skip image if vision is disabled
-                    if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
+                for content_item in prompt_message.content:
+                    # Skip content if features are not defined
+                    if not model_config.model_schema.features:
+                        if content_item.type != PromptMessageContentType.TEXT:
+                            continue
+                        prompt_message_content.append(content_item)
                         continue
-                    if isinstance(content_item, ImagePromptMessageContent):
-                        # Override vision config if LLM node has vision config,
-                        # cuz vision detail is related to the configuration from FileUpload feature.
-                        content_item.detail = vision_detail
-                        prompt_message_content.append(content_item)
-                    elif isinstance(
-                        content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
+                    # Skip content if corresponding feature is not supported
+                    if (
+                        (
+                            content_item.type == PromptMessageContentType.IMAGE
+                            and ModelFeature.VISION not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.DOCUMENT
+                            and ModelFeature.DOCUMENT not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.VIDEO
+                            and ModelFeature.VIDEO not in model_config.model_schema.features
+                        )
+                        or (
+                            content_item.type == PromptMessageContentType.AUDIO
+                            and ModelFeature.AUDIO not in model_config.model_schema.features
+                        )
                     ):
-                        prompt_message_content.append(content_item)
-                if len(prompt_message_content) > 1:
-                    prompt_message.content = prompt_message_content
-                elif (
-                    len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
-                ):
+                        continue
+                    prompt_message_content.append(content_item)
+                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                     prompt_message.content = prompt_message_content[0].data
+                else:
+                    prompt_message.content = prompt_message_content
+            if prompt_message.is_empty():
+                continue
             filtered_prompt_messages.append(prompt_message)
-        if not filtered_prompt_messages:
+        if len(filtered_prompt_messages) == 0:
             raise NoPromptFoundError(
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
+        stop = model_config.stop
         return filtered_prompt_messages, stop
     @classmethod
                 }
             },
         }

+def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
+    match role:
+        case PromptMessageRole.USER:
+            return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.ASSISTANT:
+            return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
+        case PromptMessageRole.SYSTEM:
+            return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
+    raise NotImplementedError(f"Role {role} is not supported")
+def _render_jinja2_message(
+    *,
+    template: str,
+    jinjia2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+):
+    if not template:
+        return ""
+    jinjia2_inputs = {}
+    for jinja2_variable in jinjia2_variables:
+        variable = variable_pool.get(jinja2_variable.value_selector)
+        jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+    code_execute_resp = CodeExecutor.execute_workflow_code_template(
+        language=CodeLanguage.JINJA2,
+        code=template,
+        inputs=jinjia2_inputs,
+    )
+    result_text = code_execute_resp["result"]
+    return result_text
+def _handle_list_messages(
+    *,
+    messages: Sequence[LLMNodeChatModelMessage],
+    context: Optional[str],
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    vision_detail_config: ImagePromptMessageContent.DETAIL,
+) -> Sequence[PromptMessage]:
+    prompt_messages = []
+    for message in messages:
+        if message.edition_type == "jinja2":
+            result_text = _render_jinja2_message(
+                template=message.jinja2_text or "",
+                jinjia2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+            )
+            prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
+            prompt_messages.append(prompt_message)
+        else:
+            # Get segment group from basic message
+            if context:
+                template = message.text.replace("{#context#}", context)
+            else:
+                template = message.text
+            segment_group = variable_pool.convert_template(template)
+            # Process segments for images
+            file_contents = []
+            for segment in segment_group.value:
+                if isinstance(segment, ArrayFileSegment):
+                    for file in segment.value:
+                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                            file_content = file_manager.to_prompt_message_content(
+                                file, image_detail_config=vision_detail_config
+                            )
+                            file_contents.append(file_content)
+                if isinstance(segment, FileSegment):
+                    file = segment.value
+                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                        file_content = file_manager.to_prompt_message_content(
+                            file, image_detail_config=vision_detail_config
+                        )
+                        file_contents.append(file_content)
+            # Create message with text from all segments
+            plain_text = segment_group.text
+            if plain_text:
+                prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
+                prompt_messages.append(prompt_message)
+            if file_contents:
+                # Create message with image contents
+                prompt_message = UserPromptMessage(content=file_contents)
+                prompt_messages.append(prompt_message)
+    return prompt_messages
+def _calculate_rest_token(
+    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+) -> int:
+    rest_tokens = 2000
+    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    if model_context_tokens:
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+        )
+        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    model_config.parameters.get(parameter_rule.name)
+                    or model_config.parameters.get(str(parameter_rule.use_template))
+                    or 0
+                )
+        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+        rest_tokens = max(rest_tokens, 0)
+    return rest_tokens
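A worked example of the token budget computed above, with illustrative numbers:

context_size = 200_000          # model_properties CONTEXT_SIZE
max_tokens = 4_096              # reserved for the completion
current_prompt_tokens = 1_500   # tokens already used by the prompt messages
rest_tokens = max(context_size - max_tokens - current_prompt_tokens, 0)
print(rest_tokens)  # 194404 -> budget handed to the memory window
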
+def _handle_memory_chat_mode(
+    *,
+    memory: TokenBufferMemory | None,
+    memory_config: MemoryConfig | None,
+    model_config: ModelConfigWithCredentialsEntity,
+) -> Sequence[PromptMessage]:
+    memory_messages = []
+    # Get messages from memory for chat model
+    if memory and memory_config:
+        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        memory_messages = memory.get_history_prompt_messages(
+            max_token_limit=rest_tokens,
+            message_limit=memory_config.window.size if memory_config.window.enabled else None,
+        )
+    return memory_messages
+def _handle_memory_completion_mode(
+    *,
+    memory: TokenBufferMemory | None,
+    memory_config: MemoryConfig | None,
+    model_config: ModelConfigWithCredentialsEntity,
+) -> str:
+    memory_text = ""
+    # Get history text from memory for completion model
+    if memory and memory_config:
+        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        if not memory_config.role_prefix:
+            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+        memory_text = memory.get_history_prompt_text(
+            max_token_limit=rest_tokens,
+            message_limit=memory_config.window.size if memory_config.window.enabled else None,
+            human_prefix=memory_config.role_prefix.user,
+            ai_prefix=memory_config.role_prefix.assistant,
+        )
+    return memory_text
+def _handle_completion_template(
+    *,
+    template: LLMNodeCompletionModelPromptTemplate,
+    context: Optional[str],
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+) -> Sequence[PromptMessage]:
+    """Handle completion template processing outside of LLMNode class.
+    Args:
+        template: The completion model prompt template
+        context: Optional context string
+        jinja2_variables: Variables for jinja2 template rendering
+        variable_pool: Variable pool for template conversion
+    Returns:
+        Sequence of prompt messages
+    """
+    prompt_messages = []
+    if template.edition_type == "jinja2":
+        result_text = _render_jinja2_message(
+            template=template.jinja2_text or "",
+            jinjia2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+        )
+    else:
+        if context:
+            template_text = template.text.replace("{#context#}", context)
+        else:
+            template_text = template.text
+        result_text = variable_pool.convert_template(template_text).text
+    prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
+    prompt_messages.append(prompt_message)
+    return prompt_messages
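A small illustration of the completion-mode placeholders handled above: history is spliced in at `#histories#` and the user query at `#sys.query#` (simplified; context substitution happens inside `_handle_completion_template`):

prompt = "#histories#\nContext: ...\nQuestion: #sys.query#"
memory_text = "Human: hi\nAssistant: hello"
user_query = "What changed in 0.12.0?"

rendered = prompt.replace("#histories#", memory_text).replace("#sys.query#", user_query)
print(rendered)
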
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
-            system_query=query,
+            user_query=query,
             memory=memory,
             model_config=model_config,
-            files=files,
+            user_files=files,
             vision_enabled=node_data.vision.enabled,
             vision_detail=node_data.vision.configs.detail,
+            variable_pool=variable_pool,
+            jinja2_variables=[],
         )
         # handle invoke result

 [package.extras]
 test = ["pytest (>=6)"]
+[[package]]
+name = "faker"
+version = "32.1.0"
+description = "Faker is a Python package that generates fake data for you."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
+    {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
+]
+[package.dependencies]
+python-dateutil = ">=2.4"
+typing-extensions = "*"
 [[package]]
 name = "fal-client"
 version = "0.5.6"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69"
+content-hash = "d149b24ce7a203fa93eddbe8430d8ea7e5160a89c8d348b1b747c19899065639"

 optional = true
 [tool.poetry.group.dev.dependencies]
 coverage = "~7.2.4"
+faker = "~32.1.0"
 pytest = "~8.3.2"
 pytest-benchmark = "~4.0.0"
 pytest-env = "~1.1.3"

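Typical use of the new dev dependency in tests, for illustration:

from faker import Faker

fake = Faker()
fake.seed_instance(42)  # deterministic output keeps tests reproducible

def make_fake_user() -> dict:
    return {"name": fake.name(), "email": fake.email(), "company": fake.company()}

print(make_fake_user())
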
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
+from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
 @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)

| from core.model_runtime.entities.rerank_entities import RerankResult | from core.model_runtime.entities.rerank_entities import RerankResult | ||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||||
| from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel | |||||
| from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel | |||||
| def test_validate_credentials(): | def test_validate_credentials(): | ||||
| model = AzureAIStudioRerankModel() | |||||
| model = AzureRerankModel() | |||||
| with pytest.raises(CredentialsValidateFailedError): | with pytest.raises(CredentialsValidateFailedError): | ||||
| model.validate_credentials( | model.validate_credentials( | ||||
| model="azure-ai-studio-rerank-v1", | model="azure-ai-studio-rerank-v1", | ||||
| credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, | credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, | ||||
| query="What is the capital of the United States?", | |||||
| docs=[ | |||||
| "Carson City is the capital city of the American state of Nevada. At the 2010 United States " | |||||
| "Census, Carson City had a population of 55,274.", | |||||
| "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " | |||||
| "are a political division controlled by the United States. Its capital is Saipan.", | |||||
| ], | |||||
| score_threshold=0.8, | |||||
| ) | ) | ||||
| def test_invoke_model(): | def test_invoke_model(): | ||||
| model = AzureAIStudioRerankModel() | |||||
| model = AzureRerankModel() | |||||
| result = model.invoke( | result = model.invoke( | ||||
| model="azure-ai-studio-rerank-v1", | model="azure-ai-studio-rerank-v1", |
| from collections.abc import Sequence | |||||
| from typing import Optional | |||||
| import pytest | import pytest | ||||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||||
| from configs import dify_config | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity | |||||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle | |||||
| from core.entities.provider_entities import CustomConfiguration, SystemConfiguration | |||||
| from core.file import File, FileTransferMethod, FileType | from core.file import File, FileTransferMethod, FileType | ||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||||
| from core.model_runtime.entities.common_entities import I18nObject | |||||
| from core.model_runtime.entities.message_entities import ( | |||||
| AssistantPromptMessage, | |||||
| ImagePromptMessageContent, | |||||
| PromptMessage, | |||||
| PromptMessageRole, | |||||
| SystemPromptMessage, | |||||
| TextPromptMessageContent, | |||||
| UserPromptMessage, | |||||
| ) | |||||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel | |||||
| from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity | |||||
| from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||||
| from core.prompt.entities.advanced_prompt_entities import MemoryConfig | |||||
| from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment | from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment | ||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | ||||
| from core.workflow.nodes.answer import AnswerStreamGenerateRoute | from core.workflow.nodes.answer import AnswerStreamGenerateRoute | ||||
| from core.workflow.nodes.end import EndStreamParam | from core.workflow.nodes.end import EndStreamParam | ||||
| from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions | |||||
| from core.workflow.nodes.llm.entities import ( | |||||
| ContextConfig, | |||||
| LLMNodeChatModelMessage, | |||||
| LLMNodeData, | |||||
| ModelConfig, | |||||
| VisionConfig, | |||||
| VisionConfigOptions, | |||||
| ) | |||||
| from core.workflow.nodes.llm.node import LLMNode | from core.workflow.nodes.llm.node import LLMNode | ||||
| from models.enums import UserFrom | from models.enums import UserFrom | ||||
| from models.provider import ProviderType | |||||
| from models.workflow import WorkflowType | from models.workflow import WorkflowType | ||||
| from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario | |||||
| class TestLLMNode: | |||||
| @pytest.fixture | |||||
| def llm_node(self): | |||||
| data = LLMNodeData( | |||||
| title="Test LLM", | |||||
| model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), | |||||
| prompt_template=[], | |||||
| memory=None, | |||||
| context=ContextConfig(enabled=False), | |||||
| vision=VisionConfig( | |||||
| enabled=True, | |||||
| configs=VisionConfigOptions( | |||||
| variable_selector=["sys", "files"], | |||||
| detail=ImagePromptMessageContent.DETAIL.HIGH, | |||||
| ), | |||||
| ), | |||||
| ) | |||||
| variable_pool = VariablePool( | |||||
| system_variables={}, | |||||
| user_inputs={}, | |||||
| ) | |||||
| node = LLMNode( | |||||
| id="1", | |||||
| config={ | |||||
| "id": "1", | |||||
| "data": data.model_dump(), | |||||
| }, | |||||
| graph_init_params=GraphInitParams( | |||||
| tenant_id="1", | |||||
| app_id="1", | |||||
| workflow_type=WorkflowType.WORKFLOW, | |||||
| workflow_id="1", | |||||
| graph_config={}, | |||||
| user_id="1", | |||||
| user_from=UserFrom.ACCOUNT, | |||||
| invoke_from=InvokeFrom.SERVICE_API, | |||||
| call_depth=0, | |||||
| class MockTokenBufferMemory: | |||||
| def __init__(self, history_messages=None): | |||||
| self.history_messages = history_messages or [] | |||||
| def get_history_prompt_messages( | |||||
| self, max_token_limit: int = 2000, message_limit: Optional[int] = None | |||||
| ) -> Sequence[PromptMessage]: | |||||
| if message_limit is not None: | |||||
| return self.history_messages[-message_limit * 2 :] | |||||
| return self.history_messages | |||||
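`MockTokenBufferMemory` mirrors the window arithmetic the assertions below rely on: `message_limit` counts conversational turns, so the slice keeps the last `message_limit * 2` messages (one user plus one assistant message per turn), which is why the expected histories are built with `mock_history[fake_window_size * -2:]`. A quick usage sketch, with plain strings standing in for `PromptMessage` objects:

```python
# Window arithmetic of the mock above: message_limit turns -> 2x messages kept.
history = ["u1", "a1", "u2", "a2", "u3", "a3"]  # stand-ins for PromptMessage objects
memory = MockTokenBufferMemory(history_messages=history)

assert memory.get_history_prompt_messages(message_limit=2) == ["u2", "a2", "u3", "a3"]
assert memory.get_history_prompt_messages() == history  # no limit -> full history
```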
| @pytest.fixture | |||||
| def llm_node(): | |||||
| data = LLMNodeData( | |||||
| title="Test LLM", | |||||
| model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), | |||||
| prompt_template=[], | |||||
| memory=None, | |||||
| context=ContextConfig(enabled=False), | |||||
| vision=VisionConfig( | |||||
| enabled=True, | |||||
| configs=VisionConfigOptions( | |||||
| variable_selector=["sys", "files"], | |||||
| detail=ImagePromptMessageContent.DETAIL.HIGH, | |||||
| ), | ), | ||||
| graph=Graph( | |||||
| root_node_id="1", | |||||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||||
| answer_dependencies={}, | |||||
| answer_generate_route={}, | |||||
| ), | |||||
| end_stream_param=EndStreamParam( | |||||
| end_dependencies={}, | |||||
| end_stream_variable_selector_mapping={}, | |||||
| ), | |||||
| ), | |||||
| ) | |||||
| variable_pool = VariablePool( | |||||
| system_variables={}, | |||||
| user_inputs={}, | |||||
| ) | |||||
| node = LLMNode( | |||||
| id="1", | |||||
| config={ | |||||
| "id": "1", | |||||
| "data": data.model_dump(), | |||||
| }, | |||||
| graph_init_params=GraphInitParams( | |||||
| tenant_id="1", | |||||
| app_id="1", | |||||
| workflow_type=WorkflowType.WORKFLOW, | |||||
| workflow_id="1", | |||||
| graph_config={}, | |||||
| user_id="1", | |||||
| user_from=UserFrom.ACCOUNT, | |||||
| invoke_from=InvokeFrom.SERVICE_API, | |||||
| call_depth=0, | |||||
| ), | |||||
| graph=Graph( | |||||
| root_node_id="1", | |||||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||||
| answer_dependencies={}, | |||||
| answer_generate_route={}, | |||||
| ), | ), | ||||
| graph_runtime_state=GraphRuntimeState( | |||||
| variable_pool=variable_pool, | |||||
| start_at=0, | |||||
| end_stream_param=EndStreamParam( | |||||
| end_dependencies={}, | |||||
| end_stream_variable_selector_mapping={}, | |||||
| ), | ), | ||||
| ) | |||||
| return node | |||||
| ), | |||||
| graph_runtime_state=GraphRuntimeState( | |||||
| variable_pool=variable_pool, | |||||
| start_at=0, | |||||
| ), | |||||
| ) | |||||
| return node | |||||
| @pytest.fixture | |||||
| def model_config(): | |||||
| # Create actual provider and model type instances | |||||
| model_provider_factory = ModelProviderFactory() | |||||
| provider_instance = model_provider_factory.get_provider_instance("openai") | |||||
| model_type_instance = provider_instance.get_model_instance(ModelType.LLM) | |||||
| # Create a ProviderModelBundle | |||||
| provider_model_bundle = ProviderModelBundle( | |||||
| configuration=ProviderConfiguration( | |||||
| tenant_id="1", | |||||
| provider=provider_instance.get_provider_schema(), | |||||
| preferred_provider_type=ProviderType.CUSTOM, | |||||
| using_provider_type=ProviderType.CUSTOM, | |||||
| system_configuration=SystemConfiguration(enabled=False), | |||||
| custom_configuration=CustomConfiguration(provider=None), | |||||
| model_settings=[], | |||||
| ), | |||||
| provider_instance=provider_instance, | |||||
| model_type_instance=model_type_instance, | |||||
| ) | |||||
| def test_fetch_files_with_file_segment(self, llm_node): | |||||
| file = File( | |||||
| # Create and return a ModelConfigWithCredentialsEntity | |||||
| return ModelConfigWithCredentialsEntity( | |||||
| provider="openai", | |||||
| model="gpt-3.5-turbo", | |||||
| model_schema=AIModelEntity( | |||||
| model="gpt-3.5-turbo", | |||||
| label=I18nObject(en_US="GPT-3.5 Turbo"), | |||||
| model_type=ModelType.LLM, | |||||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||||
| model_properties={}, | |||||
| ), | |||||
| mode="chat", | |||||
| credentials={}, | |||||
| parameters={}, | |||||
| provider_model_bundle=provider_model_bundle, | |||||
| ) | |||||
| def test_fetch_files_with_file_segment(llm_node): | |||||
| file = File( | |||||
| id="1", | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test.jpg", | |||||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||||
| related_id="1", | |||||
| ) | |||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [file] | |||||
| def test_fetch_files_with_array_file_segment(llm_node): | |||||
| files = [ | |||||
| File( | |||||
| id="1", | id="1", | ||||
| tenant_id="test", | tenant_id="test", | ||||
| type=FileType.IMAGE, | type=FileType.IMAGE, | ||||
| filename="test.jpg", | |||||
| filename="test1.jpg", | |||||
| transfer_method=FileTransferMethod.LOCAL_FILE, | transfer_method=FileTransferMethod.LOCAL_FILE, | ||||
| related_id="1", | related_id="1", | ||||
| ), | |||||
| File( | |||||
| id="2", | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test2.jpg", | |||||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||||
| related_id="2", | |||||
| ), | |||||
| ] | |||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == files | |||||
| def test_fetch_files_with_none_segment(llm_node): | |||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [] | |||||
| def test_fetch_files_with_array_any_segment(llm_node): | |||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [] | |||||
| def test_fetch_files_with_non_existent_variable(llm_node): | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [] | |||||
| def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): | |||||
| prompt_template = [] | |||||
| llm_node.node_data.prompt_template = prompt_template | |||||
| fake_vision_detail = faker.random_element( | |||||
| [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] | |||||
| ) | |||||
| fake_remote_url = faker.url() | |||||
| files = [ | |||||
| File( | |||||
| id="1", | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test1.jpg", | |||||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||||
| remote_url=fake_remote_url, | |||||
| ) | ) | ||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [file] | |||||
| def test_fetch_files_with_array_file_segment(self, llm_node): | |||||
| files = [ | |||||
| File( | |||||
| id="1", | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test1.jpg", | |||||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||||
| related_id="1", | |||||
| ), | |||||
| File( | |||||
| id="2", | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test2.jpg", | |||||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||||
| related_id="2", | |||||
| ), | |||||
| ] | |||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) | |||||
| ] | |||||
| fake_query = faker.sentence() | |||||
| prompt_messages, _ = llm_node._fetch_prompt_messages( | |||||
| user_query=fake_query, | |||||
| user_files=files, | |||||
| context=None, | |||||
| memory=None, | |||||
| model_config=model_config, | |||||
| prompt_template=prompt_template, | |||||
| memory_config=None, | |||||
| vision_enabled=False, | |||||
| vision_detail=fake_vision_detail, | |||||
| variable_pool=llm_node.graph_runtime_state.variable_pool, | |||||
| jinja2_variables=[], | |||||
| ) | |||||
| assert prompt_messages == [UserPromptMessage(content=fake_query)] | |||||
| def test_fetch_prompt_messages__basic(faker, llm_node, model_config): | |||||
| # Setup dify config | |||||
| dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" | |||||
| dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" | |||||
| # Generate fake values for prompt template | |||||
| fake_assistant_prompt = faker.sentence() | |||||
| fake_query = faker.sentence() | |||||
| fake_context = faker.sentence() | |||||
| fake_window_size = faker.random_int(min=1, max=3) | |||||
| fake_vision_detail = faker.random_element( | |||||
| [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] | |||||
| ) | |||||
| fake_remote_url = faker.url() | |||||
| # Setup mock memory with history messages | |||||
| mock_history = [ | |||||
| UserPromptMessage(content=faker.sentence()), | |||||
| AssistantPromptMessage(content=faker.sentence()), | |||||
| UserPromptMessage(content=faker.sentence()), | |||||
| AssistantPromptMessage(content=faker.sentence()), | |||||
| UserPromptMessage(content=faker.sentence()), | |||||
| AssistantPromptMessage(content=faker.sentence()), | |||||
| ] | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == files | |||||
| # Setup memory configuration | |||||
| memory_config = MemoryConfig( | |||||
| role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), | |||||
| window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size), | |||||
| query_prompt_template=None, | |||||
| ) | |||||
| def test_fetch_files_with_none_segment(self, llm_node): | |||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) | |||||
| memory = MockTokenBufferMemory(history_messages=mock_history) | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [] | |||||
| # Test scenarios covering different file input combinations | |||||
| test_scenarios = [ | |||||
| LLMNodeTestScenario( | |||||
| description="No files", | |||||
| user_query=fake_query, | |||||
| user_files=[], | |||||
| features=[], | |||||
| vision_enabled=False, | |||||
| vision_detail=None, | |||||
| window_size=fake_window_size, | |||||
| prompt_template=[ | |||||
| LLMNodeChatModelMessage( | |||||
| text=fake_context, | |||||
| role=PromptMessageRole.SYSTEM, | |||||
| edition_type="basic", | |||||
| ), | |||||
| LLMNodeChatModelMessage( | |||||
| text="{#context#}", | |||||
| role=PromptMessageRole.USER, | |||||
| edition_type="basic", | |||||
| ), | |||||
| LLMNodeChatModelMessage( | |||||
| text=fake_assistant_prompt, | |||||
| role=PromptMessageRole.ASSISTANT, | |||||
| edition_type="basic", | |||||
| ), | |||||
| ], | |||||
| expected_messages=[ | |||||
| SystemPromptMessage(content=fake_context), | |||||
| UserPromptMessage(content=fake_context), | |||||
| AssistantPromptMessage(content=fake_assistant_prompt), | |||||
| ] | |||||
| + mock_history[fake_window_size * -2 :] | |||||
| + [ | |||||
| UserPromptMessage(content=fake_query), | |||||
| ], | |||||
| ), | |||||
| LLMNodeTestScenario( | |||||
| description="User files", | |||||
| user_query=fake_query, | |||||
| user_files=[ | |||||
| File( | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test1.jpg", | |||||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||||
| remote_url=fake_remote_url, | |||||
| ) | |||||
| ], | |||||
| vision_enabled=True, | |||||
| vision_detail=fake_vision_detail, | |||||
| features=[ModelFeature.VISION], | |||||
| window_size=fake_window_size, | |||||
| prompt_template=[ | |||||
| LLMNodeChatModelMessage( | |||||
| text=fake_context, | |||||
| role=PromptMessageRole.SYSTEM, | |||||
| edition_type="basic", | |||||
| ), | |||||
| LLMNodeChatModelMessage( | |||||
| text="{#context#}", | |||||
| role=PromptMessageRole.USER, | |||||
| edition_type="basic", | |||||
| ), | |||||
| LLMNodeChatModelMessage( | |||||
| text=fake_assistant_prompt, | |||||
| role=PromptMessageRole.ASSISTANT, | |||||
| edition_type="basic", | |||||
| ), | |||||
| ], | |||||
| expected_messages=[ | |||||
| SystemPromptMessage(content=fake_context), | |||||
| UserPromptMessage(content=fake_context), | |||||
| AssistantPromptMessage(content=fake_assistant_prompt), | |||||
| ] | |||||
| + mock_history[fake_window_size * -2 :] | |||||
| + [ | |||||
| UserPromptMessage( | |||||
| content=[ | |||||
| TextPromptMessageContent(data=fake_query), | |||||
| ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), | |||||
| ] | |||||
| ), | |||||
| ], | |||||
| ), | |||||
| LLMNodeTestScenario( | |||||
| description="Prompt template with variable selector of File", | |||||
| user_query=fake_query, | |||||
| user_files=[], | |||||
| vision_enabled=False, | |||||
| vision_detail=fake_vision_detail, | |||||
| features=[ModelFeature.VISION], | |||||
| window_size=fake_window_size, | |||||
| prompt_template=[ | |||||
| LLMNodeChatModelMessage( | |||||
| text="{{#input.image#}}", | |||||
| role=PromptMessageRole.USER, | |||||
| edition_type="basic", | |||||
| ), | |||||
| ], | |||||
| expected_messages=[ | |||||
| UserPromptMessage( | |||||
| content=[ | |||||
| ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), | |||||
| ] | |||||
| ), | |||||
| ] | |||||
| + mock_history[fake_window_size * -2 :] | |||||
| + [UserPromptMessage(content=fake_query)], | |||||
| file_variables={ | |||||
| "input.image": File( | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test1.jpg", | |||||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||||
| remote_url=fake_remote_url, | |||||
| ) | |||||
| }, | |||||
| ), | |||||
| LLMNodeTestScenario( | |||||
| description="Prompt template with variable selector of File without vision feature", | |||||
| user_query=fake_query, | |||||
| user_files=[], | |||||
| vision_enabled=True, | |||||
| vision_detail=fake_vision_detail, | |||||
| features=[], | |||||
| window_size=fake_window_size, | |||||
| prompt_template=[ | |||||
| LLMNodeChatModelMessage( | |||||
| text="{{#input.image#}}", | |||||
| role=PromptMessageRole.USER, | |||||
| edition_type="basic", | |||||
| ), | |||||
| ], | |||||
| expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], | |||||
| file_variables={ | |||||
| "input.image": File( | |||||
| tenant_id="test", | |||||
| type=FileType.IMAGE, | |||||
| filename="test1.jpg", | |||||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||||
| remote_url=fake_remote_url, | |||||
| ) | |||||
| }, | |||||
| ), | |||||
| LLMNodeTestScenario( | |||||
| description="Prompt template with variable selector of File with video file and vision feature", | |||||
| user_query=fake_query, | |||||
| user_files=[], | |||||
| vision_enabled=True, | |||||
| vision_detail=fake_vision_detail, | |||||
| features=[ModelFeature.VISION], | |||||
| window_size=fake_window_size, | |||||
| prompt_template=[ | |||||
| LLMNodeChatModelMessage( | |||||
| text="{{#input.image#}}", | |||||
| role=PromptMessageRole.USER, | |||||
| edition_type="basic", | |||||
| ), | |||||
| ], | |||||
| expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], | |||||
| file_variables={ | |||||
| "input.image": File( | |||||
| tenant_id="test", | |||||
| type=FileType.VIDEO, | |||||
| filename="test1.mp4", | |||||
| transfer_method=FileTransferMethod.REMOTE_URL, | |||||
| remote_url=fake_remote_url, | |||||
| extension="mp4", | |||||
| ) | |||||
| }, | |||||
| ), | |||||
| ] | |||||
| def test_fetch_files_with_array_any_segment(self, llm_node): | |||||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) | |||||
| for scenario in test_scenarios: | |||||
| model_config.model_schema.features = scenario.features | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [] | |||||
| for k, v in scenario.file_variables.items(): | |||||
| selector = k.split(".") | |||||
| llm_node.graph_runtime_state.variable_pool.add(selector, v) | |||||
| # Call the method under test | |||||
| prompt_messages, _ = llm_node._fetch_prompt_messages( | |||||
| user_query=scenario.user_query, | |||||
| user_files=scenario.user_files, | |||||
| context=fake_context, | |||||
| memory=memory, | |||||
| model_config=model_config, | |||||
| prompt_template=scenario.prompt_template, | |||||
| memory_config=memory_config, | |||||
| vision_enabled=scenario.vision_enabled, | |||||
| vision_detail=scenario.vision_detail, | |||||
| variable_pool=llm_node.graph_runtime_state.variable_pool, | |||||
| jinja2_variables=[], | |||||
| ) | |||||
| def test_fetch_files_with_non_existent_variable(self, llm_node): | |||||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||||
| assert result == [] | |||||
| # Verify the result | |||||
| assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" | |||||
| assert ( | |||||
| prompt_messages == scenario.expected_messages | |||||
| ), f"Message content mismatch in scenario: {scenario.description}" |
| from collections.abc import Mapping, Sequence | |||||
| from pydantic import BaseModel, Field | |||||
| from core.file import File | |||||
| from core.model_runtime.entities.message_entities import PromptMessage | |||||
| from core.model_runtime.entities.model_entities import ModelFeature | |||||
| from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage | |||||
| class LLMNodeTestScenario(BaseModel): | |||||
| """Test scenario for LLM node testing.""" | |||||
| description: str = Field(..., description="Description of the test scenario") | |||||
| user_query: str = Field(..., description="User query input") | |||||
| user_files: Sequence[File] = Field(default_factory=list, description="List of user files") | |||||
| vision_enabled: bool = Field(default=False, description="Whether vision is enabled") | |||||
| vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") | |||||
| features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") | |||||
| window_size: int = Field(..., description="Window size for memory") | |||||
| prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") | |||||
| file_variables: Mapping[str, File | Sequence[File]] = Field( | |||||
| default_factory=dict, description="List of file variables" | |||||
| ) | |||||
| expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") |
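For reference, a hypothetical scenario built on this model; fields left out fall back to their declared defaults (no user files, no model features, empty `file_variables`). The concrete strings are illustrative only, and in the real tests `expected_messages` also includes the sliced memory history and the trailing user query:

```python
from core.model_runtime.entities.message_entities import (
    PromptMessageRole,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario

# Hypothetical scenario: a system + user template with the {#context#} placeholder,
# which the node is expected to resolve to the context string supplied by the test.
scenario = LLMNodeTestScenario(
    description="System + user prompt, no files",
    user_query="What does the knowledge base say?",
    window_size=2,
    prompt_template=[
        LLMNodeChatModelMessage(text="You are helpful.", role=PromptMessageRole.SYSTEM, edition_type="basic"),
        LLMNodeChatModelMessage(text="{#context#}", role=PromptMessageRole.USER, edition_type="basic"),
    ],
    expected_messages=[
        SystemPromptMessage(content="You are helpful."),
        UserPromptMessage(content="retrieved context"),  # assumes context="retrieved context" is passed
    ],
)
```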
| hideSearch | hideSearch | ||||
| vars={availableVars} | vars={availableVars} | ||||
| onChange={handleSelectVar} | onChange={handleSelectVar} | ||||
| isSupportFileVar={false} | |||||
| /> | /> | ||||
| </div> | </div> | ||||
| )} | )} |
| isSupportConstantValue?: boolean | isSupportConstantValue?: boolean | ||||
| onlyLeafNodeVar?: boolean | onlyLeafNodeVar?: boolean | ||||
| filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean | filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean | ||||
| isSupportFileVar?: boolean | |||||
| } | } | ||||
| const VarList: FC<Props> = ({ | const VarList: FC<Props> = ({ | ||||
| isSupportConstantValue, | isSupportConstantValue, | ||||
| onlyLeafNodeVar, | onlyLeafNodeVar, | ||||
| filterVar, | filterVar, | ||||
| isSupportFileVar = true, | |||||
| }) => { | }) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| defaultVarKindType={item.variable_type} | defaultVarKindType={item.variable_type} | ||||
| onlyLeafNodeVar={onlyLeafNodeVar} | onlyLeafNodeVar={onlyLeafNodeVar} | ||||
| filterVar={filterVar} | filterVar={filterVar} | ||||
| isSupportFileVar={isSupportFileVar} | |||||
| /> | /> | ||||
| {!readonly && ( | {!readonly && ( | ||||
| <RemoveButton | <RemoveButton |
| isInTable?: boolean | isInTable?: boolean | ||||
| onRemove?: () => void | onRemove?: () => void | ||||
| typePlaceHolder?: string | typePlaceHolder?: string | ||||
| isSupportFileVar?: boolean | |||||
| } | } | ||||
| const VarReferencePicker: FC<Props> = ({ | const VarReferencePicker: FC<Props> = ({ | ||||
| isInTable, | isInTable, | ||||
| onRemove, | onRemove, | ||||
| typePlaceHolder, | typePlaceHolder, | ||||
| isSupportFileVar = true, | |||||
| }) => { | }) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const store = useStoreApi() | const store = useStoreApi() | ||||
| vars={outputVars} | vars={outputVars} | ||||
| onChange={handleVarReferenceChange} | onChange={handleVarReferenceChange} | ||||
| itemWidth={isAddBtnTrigger ? 260 : triggerWidth} | itemWidth={isAddBtnTrigger ? 260 : triggerWidth} | ||||
| isSupportFileVar={isSupportFileVar} | |||||
| /> | /> | ||||
| )} | )} | ||||
| </PortalToFollowElemContent> | </PortalToFollowElemContent> |
| vars: NodeOutPutVar[] | vars: NodeOutPutVar[] | ||||
| onChange: (value: ValueSelector, varDetail: Var) => void | onChange: (value: ValueSelector, varDetail: Var) => void | ||||
| itemWidth?: number | itemWidth?: number | ||||
| isSupportFileVar?: boolean | |||||
| } | } | ||||
| const VarReferencePopup: FC<Props> = ({ | const VarReferencePopup: FC<Props> = ({ | ||||
| vars, | vars, | ||||
| onChange, | onChange, | ||||
| itemWidth, | itemWidth, | ||||
| isSupportFileVar = true, | |||||
| }) => { | }) => { | ||||
| // max-h-[300px] overflow-y-auto todo: use portal to handle long list | // max-h-[300px] overflow-y-auto todo: use portal to handle long list | ||||
| return ( | return ( | ||||
| vars={vars} | vars={vars} | ||||
| onChange={onChange} | onChange={onChange} | ||||
| itemWidth={itemWidth} | itemWidth={itemWidth} | ||||
| isSupportFileVar | |||||
| isSupportFileVar={isSupportFileVar} | |||||
| /> | /> | ||||
| </div > | </div > | ||||
| ) | ) |
| list={inputs.variables} | list={inputs.variables} | ||||
| onChange={handleVarListChange} | onChange={handleVarListChange} | ||||
| filterVar={filterVar} | filterVar={filterVar} | ||||
| isSupportFileVar={false} | |||||
| /> | /> | ||||
| </Field> | </Field> | ||||
| <Split /> | <Split /> |
| onEditionTypeChange={onEditionTypeChange} | onEditionTypeChange={onEditionTypeChange} | ||||
| varList={varList} | varList={varList} | ||||
| handleAddVariable={handleAddVariable} | handleAddVariable={handleAddVariable} | ||||
| isSupportFileVar | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } |
| handleStop, | handleStop, | ||||
| varInputs, | varInputs, | ||||
| runResult, | runResult, | ||||
| filterJinjia2InputVar, | |||||
| } = useConfig(id, data) | } = useConfig(id, data) | ||||
| const model = inputs.model | const model = inputs.model | ||||
| list={inputs.prompt_config?.jinja2_variables || []} | list={inputs.prompt_config?.jinja2_variables || []} | ||||
| onChange={handleVarListChange} | onChange={handleVarListChange} | ||||
| onVarNameChange={handleVarNameChange} | onVarNameChange={handleVarNameChange} | ||||
| filterVar={filterVar} | |||||
| filterVar={filterJinjia2InputVar} | |||||
| isSupportFileVar={false} | |||||
| /> | /> | ||||
| </Field> | </Field> | ||||
| )} | )} | ||||
| hasSetBlockStatus={hasSetBlockStatus} | hasSetBlockStatus={hasSetBlockStatus} | ||||
| nodesOutputVars={availableVars} | nodesOutputVars={availableVars} | ||||
| availableNodes={availableNodesWithParent} | availableNodes={availableNodesWithParent} | ||||
| isSupportFileVar | |||||
| /> | /> | ||||
| {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( | {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( |
| }, [inputs, setInputs]) | }, [inputs, setInputs]) | ||||
| const filterInputVar = useCallback((varPayload: Var) => { | const filterInputVar = useCallback((varPayload: Var) => { | ||||
| return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type) | |||||
| }, []) | |||||
| const filterJinjia2InputVar = useCallback((varPayload: Var) => { | |||||
| return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) | return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) | ||||
| }, []) | }, []) | ||||
| const filterMemoryPromptVar = useCallback((varPayload: Var) => { | const filterMemoryPromptVar = useCallback((varPayload: Var) => { | ||||
| return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) | |||||
| return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type) | |||||
| }, []) | }, []) | ||||
| const { | const { | ||||
| handleRun, | handleRun, | ||||
| handleStop, | handleStop, | ||||
| runResult, | runResult, | ||||
| filterJinjia2InputVar, | |||||
| } | } | ||||
| } | } | ||||
| onChange={handleVarListChange} | onChange={handleVarListChange} | ||||
| onVarNameChange={handleVarNameChange} | onVarNameChange={handleVarNameChange} | ||||
| filterVar={filterVar} | filterVar={filterVar} | ||||
| isSupportFileVar={false} | |||||
| /> | /> | ||||
| </Field> | </Field> | ||||
| <Split /> | <Split /> |