Sfoglia il codice sorgente

chore: the consistency of MultiModalPromptMessageContent (#11721)

tags/0.14.1
非法操作 10 mesi fa
parent
commit
c9b4029ce7
Nessun account collegato all'indirizzo email del committer

+ 1
- 2
api/.env.example Vedi File

@@ -313,8 +313,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50

# Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024


+ 4
- 9
api/configs/feature/__init__.py Vedi File

@@ -665,14 +665,9 @@ class IndexingConfig(BaseSettings):
)


class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)

MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)

@@ -778,13 +773,13 @@ class FeatureConfig(
FileAccessConfig,
FileUploadConfig,
HttpConfig,
VisionFormatConfig,
InnerAPIConfig,
IndexingConfig,
LoggingConfig,
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,

+ 25
- 32
api/core/file/file_manager.py Vedi File

@@ -42,33 +42,31 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)

return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
case FileType.AUDIO:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _to_base64_data_string(f)
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")

params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}

try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")


def download(f: File, /):
@@ -122,11 +120,6 @@ def _get_encoded_string(f: File, /):
return encoded_string


def _to_base64_data_string(f: File, /):
    """Return the file's encoded content wrapped in a ``data:`` URI string."""
    # Encode first, then prefix the result with the file's own MIME type.
    payload = _get_encoded_string(f)
    return "data:{};base64,{}".format(f.mime_type, payload)


def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:

+ 24
- 15
api/core/model_runtime/entities/message_entities.py Vedi File

@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Literal, Optional
from typing import Optional

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, computed_field, field_validator


class PromptMessageRole(Enum):
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
"""

type: PromptMessageContentType
data: str


class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,35 @@ class TextPromptMessageContent(PromptMessageContent):
"""

type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str


class MultiModalPromptMessageContent(PromptMessageContent):
    """
    Model class for multi-modal prompt message content.
    """

    # Concrete subclasses pin ``type`` to a specific content type.
    type: PromptMessageContentType
    # ``format`` and ``mime_type`` are required; ``base64_data`` and ``url``
    # both default to an empty string.
    format: str = Field(..., description="the format of multi-modal file")
    base64_data: str = Field("", description="the base64 data of multi-modal file")
    url: str = Field("", description="the url of multi-modal file")
    mime_type: str = Field(..., description="the mime type of multi-modal file")

    @computed_field(return_type=str)
    @property
    def data(self):
        # A non-empty URL wins; otherwise wrap the base64 payload in a
        # ``data:`` URI built from the MIME type.
        if self.url:
            return self.url
        return f"data:{self.mime_type};base64,{self.base64_data}"

class VideoPromptMessageContent(PromptMessageContent):

class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")


class AudioPromptMessageContent(PromptMessageContent):
class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")


class ImagePromptMessageContent(PromptMessageContent):
class ImagePromptMessageContent(MultiModalPromptMessageContent):
"""
Model class for image prompt message content.
"""
@@ -101,14 +114,10 @@ class ImagePromptMessageContent(PromptMessageContent):

type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")


class DocumentPromptMessageContent(PromptMessageContent):
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
data: str
format: str = Field(..., description="Document format")


class PromptMessage(ABC, BaseModel):

+ 10
- 17
api/core/model_runtime/model_providers/anthropic/llm/llm.py Vedi File

@@ -1,5 +1,4 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
@@ -18,7 +17,6 @@ from anthropic.types import (
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
@@ -498,22 +496,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
if not message_content.base64_data:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
image_content = requests.get(message_content.url).content
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
base64_data = message_content.base64_data

mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
@@ -526,19 +521,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {mime_type}, " "only support application/pdf"
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": mime_type,
"data": base64_data,
"type": "base64",
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)

+ 3
- 3
api/core/model_runtime/model_providers/tongyi/llm/llm.py Vedi File

@@ -434,9 +434,9 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
video_url = message_content.url
if not video_url:
raise InvokeError("not support base64, please set MULTIMODAL_SEND_FORMAT to url")

sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)

+ 3
- 1
api/tests/integration_tests/model_runtime/azure_openai/test_llm.py
File diff soppresso perché troppo grande
Vedi File


+ 9
- 3
api/tests/integration_tests/model_runtime/google/test_llm.py
File diff soppresso perché troppo grande
Vedi File


+ 6
- 2
api/tests/integration_tests/model_runtime/ollama/test_llm.py
File diff soppresso perché troppo grande
Vedi File


+ 3
- 1
api/tests/integration_tests/model_runtime/openai/test_llm.py
File diff soppresso perché troppo grande
Vedi File


+ 5
- 1
api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py Vedi File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch

import pytest

from configs import dify_config
from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):

def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
model_config_mock, _, messages, inputs, context = get_chat_model_args
dify_config.MULTIMODAL_SEND_FORMAT = "url"

files = [
File(
@@ -140,7 +142,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
mock_get_encoded_string.return_value = ImagePromptMessageContent(
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,

+ 12
- 8
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py Vedi File

@@ -18,8 +18,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@@ -249,8 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):

def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
dify_config.MULTIMODAL_SEND_FORMAT = "url"

# Generate fake values for prompt template
fake_assistant_prompt = faker.sentence()
@@ -326,9 +324,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
)
],
vision_enabled=True,
@@ -362,7 +361,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
],
@@ -385,7 +386,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
ImagePromptMessageContent(
url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail
),
]
),
]
@@ -396,9 +399,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
)
},
),

+ 3
- 4
docker/.env.example Vedi File

@@ -614,13 +614,12 @@ CODE_GENERATION_MAX_TOKENS=1024
# Multi-modal Configuration
# ------------------------------

# The format of the image/video sent when the multi-modal model is input,
# The format in which image/video/audio/document files are sent to a multi-modal model as input.
# The default is base64; url is also supported.
# Calls made in url mode have lower latency than calls made in base64 mode.
# It is generally recommended to use the more compatible base64 mode.
# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video.
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document.
MULTIMODAL_SEND_FORMAT=base64

# Upload image file size limit, default 10M.
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10

+ 1
- 2
docker/docker-compose.yaml Vedi File

@@ -225,8 +225,7 @@ x-shared-env: &shared-api-worker-env
UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64}
MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64}
MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}

Loading…
Annulla
Salva