Browse Source

feat: Allow using file variables directly in the LLM node and support more file types. (#10679)

Co-authored-by: Joel <iamjoel007@gmail.com>
tags/0.12.0
-LAN- committed 11 months ago
parent
commit c5f7d650b5
36 changed files with 1,036 additions and 268 deletions
1. api/configs/app_config.py (+0 / -1)
2. api/core/app/app_config/easy_ui_based_app/model_config/converter.py (+19 / -23)
3. api/core/app/task_pipeline/workflow_cycle_manage.py (+4 / -1)
4. api/core/file/file_manager.py (+25 / -44)
5. api/core/memory/token_buffer_memory.py (+2 / -1)
6. api/core/model_manager.py (+2 / -2)
7. api/core/model_runtime/callbacks/base_callback.py (+5 / -4)
8. api/core/model_runtime/entities/__init__.py (+2 / -0)
9. api/core/model_runtime/entities/message_entities.py (+11 / -2)
10. api/core/model_runtime/entities/model_entities.py (+3 / -0)
11. api/core/model_runtime/model_providers/__base/large_language_model.py (+10 / -10)
12. api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml (+1 / -0)
13. api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml (+1 / -0)
14. api/core/model_runtime/model_providers/anthropic/llm/llm.py (+30 / -6)
15. api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml (+1 / -0)
16. api/core/prompt/utils/prompt_message_util.py (+2 / -1)
17. api/core/variables/segments.py (+10 / -2)
18. api/core/workflow/nodes/llm/entities.py (+16 / -2)
19. api/core/workflow/nodes/llm/exc.py (+8 / -0)
20. api/core/workflow/nodes/llm/node.py (+357 / -55)
21. api/core/workflow/nodes/question_classifier/question_classifier_node.py (+4 / -2)
22. api/poetry.lock (+16 / -1)
23. api/pyproject.toml (+1 / -0)
24. api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py (+0 / -1)
25. api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py (+3 / -11)
26. api/tests/unit_tests/core/workflow/nodes/llm/test_node.py (+455 / -96)
27. api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py (+25 / -0)
28. web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx (+1 / -0)
29. web/app/components/workflow/nodes/_base/components/variable/var-list.tsx (+3 / -0)
30. web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx (+3 / -0)
31. web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx (+3 / -1)
32. web/app/components/workflow/nodes/code/panel.tsx (+1 / -0)
33. web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx (+1 / -0)
34. web/app/components/workflow/nodes/llm/panel.tsx (+4 / -1)
35. web/app/components/workflow/nodes/llm/use-config.ts (+6 / -1)
36. web/app/components/workflow/nodes/template-transform/panel.tsx (+1 / -0)

api/configs/app_config.py (+0 / -1)

          # read from dotenv format config file
          env_file=".env",
          env_file_encoding="utf-8",
-         frozen=True,
          # ignore extra attributes
          extra="ignore",
      )

api/core/app/app_config/easy_ui_based_app/model_config/converter.py (+19 / -23)



class ModelConfigConverter: class ModelConfigConverter:
@classmethod @classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
""" """
Convert app model config dict to entity. Convert app model config dict to entity.
:param app_config: app config :param app_config: app config
) )


if model_credentials is None: if model_credentials is None:
if not skip_check:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else:
model_credentials = {}

if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
)

if provider_model is None:
model_name = model_config.model
raise ValueError(f"Model {model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
)

if provider_model is None:
model_name = model_config.model
raise ValueError(f"Model {model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")


# model config # model config
completion_params = model_config.parameters completion_params = model_config.parameters


model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)


if not skip_check and not model_schema:
if not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")


return ModelConfigWithCredentialsEntity( return ModelConfigWithCredentialsEntity(

api/core/app/task_pipeline/workflow_cycle_manage.py (+4 / -1)

          ).total_seconds()
          db.session.commit()

-         db.session.refresh(workflow_run)
          db.session.close()

+         with Session(db.engine, expire_on_commit=False) as session:
+             session.add(workflow_run)
+             session.refresh(workflow_run)
+
          if trace_manager:
              trace_manager.add_trace_task(
                  TraceTask(
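The replacement re-attaches the already-loaded `workflow_run` to a short-lived session instead of refreshing it on the scoped `db.session`, which may already be closed at this point. A minimal sketch of the SQLAlchemy pattern being introduced (names follow the diff above):

```python
from sqlalchemy.orm import Session

# expire_on_commit=False keeps the refreshed attributes usable after the block exits
with Session(db.engine, expire_on_commit=False) as session:
    session.add(workflow_run)      # re-attach the detached instance to this session
    session.refresh(workflow_run)  # reload its current state from the database
```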

api/core/file/file_manager.py (+25 / -44)

from configs import dify_config from configs import dify_config
from core.file import file_repository from core.file import file_repository
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage


return file.remote_url return file.remote_url
case FileAttribute.EXTENSION: case FileAttribute.EXTENSION:
return file.extension return file.extension
case _:
raise ValueError(f"Invalid file attribute: {attr}")




def to_prompt_message_content( def to_prompt_message_content(
f: File, f: File,
/, /,
*, *,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
): ):
"""
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.

This function takes a File object and converts it to an appropriate PromptMessageContent
object, which can be used as a prompt for image or audio-based AI models.

Args:
f (File): The File object to convert.
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.

Returns:
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level

Raises:
ValueError: If the file type is not supported or if required data is missing.
"""
match f.type: match f.type:
case FileType.IMAGE: case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f) data = _to_url(f)
else: else:


return ImagePromptMessageContent(data=data, detail=image_detail_config) return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO: case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f)
encoded_string = _get_encoded_string(f)
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
data = _to_url(f) data = _to_url(f)
else: else:
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _: case _:
raise ValueError("file type f.type is not supported")
raise ValueError(f"file type {f.type} is not supported")




def download(f: File, /): def download(f: File, /):
case FileTransferMethod.REMOTE_URL: case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status() response.raise_for_status()
content = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
return encoded_string
data = response.content
case FileTransferMethod.LOCAL_FILE: case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f) upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key) data = _download_file_content(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case FileTransferMethod.TOOL_FILE: case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f) tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key) data = _download_file_content(tool_file.file_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string




def _to_base64_data_string(f: File, /): def _to_base64_data_string(f: File, /):
return f"data:{f.mime_type};base64,{encoded_string}" return f"data:{f.mime_type};base64,{encoded_string}"




def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:
raise ValueError(f"file type {f.type} is not supported")


def _to_url(f: File, /): def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None: if f.remote_url is None:
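With the new `FileType.DOCUMENT` branch, a document file variable can be turned directly into a prompt content part. A minimal usage sketch (the `File` fields mirror the unit tests in this commit and are illustrative only):

```python
from core.file import File, FileTransferMethod, FileType, file_manager

pdf = File(
    id="1",
    tenant_id="test",
    type=FileType.DOCUMENT,
    filename="report.pdf",
    mime_type="application/pdf",  # required: the DOCUMENT branch raises if mime_type is missing
    transfer_method=FileTransferMethod.LOCAL_FILE,
    related_id="1",
)

content = file_manager.to_prompt_message_content(pdf)
# content is a DocumentPromptMessageContent with encode_format="base64",
# mime_type="application/pdf", and data set to the base64-encoded file bytes
```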

api/core/memory/token_buffer_memory.py (+2 / -1)

+ from collections.abc import Sequence
  from typing import Optional

  from core.app.app_config.features.file_upload.manager import FileUploadConfigManager

      def get_history_prompt_messages(
          self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-     ) -> list[PromptMessage]:
+     ) -> Sequence[PromptMessage]:
          """
          Get history prompt messages.
          :param max_token_limit: max token limit

api/core/model_manager.py (+2 / -2)

      def invoke_llm(
          self,
-         prompt_messages: list[PromptMessage],
+         prompt_messages: Sequence[PromptMessage],
          model_parameters: Optional[dict] = None,
          tools: Sequence[PromptMessageTool] | None = None,
-         stop: Optional[list[str]] = None,
+         stop: Optional[Sequence[str]] = None,
          stream: bool = True,
          user: Optional[str] = None,
          callbacks: Optional[list[Callback]] = None,

api/core/model_runtime/callbacks/base_callback.py (+5 / -4)

from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional from typing import Optional


from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
): ):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:

api/core/model_runtime/entities/__init__.py (+2 / -0)

  from .message_entities import (
      AssistantPromptMessage,
      AudioPromptMessageContent,
+     DocumentPromptMessageContent,
      ImagePromptMessageContent,
      PromptMessage,
      PromptMessageContent,

      "LLMResultChunk",
      "LLMResultChunkDelta",
      "AudioPromptMessageContent",
+     "DocumentPromptMessageContent",
  ]

api/core/model_runtime/entities/message_entities.py (+11 / -2)

  from abc import ABC
+ from collections.abc import Sequence
  from enum import Enum
- from typing import Optional
+ from typing import Literal, Optional

  from pydantic import BaseModel, Field, field_validator

      IMAGE = "image"
      AUDIO = "audio"
      VIDEO = "video"
+     DOCUMENT = "document"


  class PromptMessageContent(BaseModel):

      detail: DETAIL = DETAIL.LOW


+ class DocumentPromptMessageContent(PromptMessageContent):
+     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+     encode_format: Literal["base64"]
+     mime_type: str
+     data: str
+
+
  class PromptMessage(ABC, BaseModel):
      """
      Model class for prompt message.
      """

      role: PromptMessageRole
-     content: Optional[str | list[PromptMessageContent]] = None
+     content: Optional[str | Sequence[PromptMessageContent]] = None
      name: Optional[str] = None

      def is_empty(self) -> bool:
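For reference, a hedged sketch of how the new content type is meant to be carried inside a prompt message; `UserPromptMessage` is imported elsewhere in this commit and the field values below are placeholders:

```python
from core.model_runtime.entities import DocumentPromptMessageContent, UserPromptMessage

message = UserPromptMessage(
    content=[
        DocumentPromptMessageContent(
            encode_format="base64",            # only "base64" is allowed by the Literal type
            mime_type="application/pdf",
            data="<base64-encoded document bytes>",
        ),
    ]
)
```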

api/core/model_runtime/entities/model_entities.py (+3 / -0)

      AGENT_THOUGHT = "agent-thought"
      VISION = "vision"
      STREAM_TOOL_CALL = "stream-tool-call"
+     DOCUMENT = "document"
+     VIDEO = "video"
+     AUDIO = "audio"


  class DefaultParameterName(str, Enum):

api/core/model_runtime/model_providers/__base/large_language_model.py (+10 / -10)

import re import re
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union from typing import Optional, Union


from pydantic import ConfigDict from pydantic import ConfigDict
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) )


model_parameters.pop("response_format") model_parameters.pop("response_format")
stop = stop or []
stop = list(stop) if stop is not None else []
stop.extend(["\n```", "```\n"]) stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block) block_prompts = block_prompts.replace("{{block}}", code_block)


prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml (+1 / -0)

    - vision
    - tool-call
    - stream-tool-call
+   - document
  model_properties:
    mode: chat
    context_size: 200000

api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml (+1 / -0)

    - vision
    - tool-call
    - stream-tool-call
+   - document
  model_properties:
    mode: chat
    context_size: 200000

api/core/model_runtime/model_providers/anthropic/llm/llm.py (+30 / -6)

import base64 import base64
import io import io
import json import json
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast from typing import Optional, Union, cast


import anthropic import anthropic
from PIL import Image from PIL import Image


from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError, InvokeAuthorizationError,
InvokeBadRequestError, InvokeBadRequestError,
self, self,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
# Add the new header for claude-3-5-sonnet-20240620 model # Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {} extra_headers = {}
if model == "claude-3-5-sonnet-20240620": if model == "claude-3-5-sonnet-20240620":
if model_parameters.get("max_tokens") > 4096:
if model_parameters.get("max_tokens", 0) > 4096:
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"


if any(
isinstance(content, DocumentPromptMessageContent)
for prompt_message in prompt_messages
if isinstance(prompt_message.content, list)
for content in prompt_message.content
):
extra_headers["anthropic-beta"] = "pdfs-2024-09-25"

if tools: if tools:
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
response = client.beta.tools.messages.create( response = client.beta.tools.messages.create(
"source": {"type": "base64", "media_type": mime_type, "data": base64_data}, "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)
prompt_message_dicts.append({"role": "user", "content": sub_messages}) prompt_message_dicts.append({"role": "user", "content": sub_messages})
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message) message = cast(AssistantPromptMessage, message)
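The new branch above maps a `DocumentPromptMessageContent` onto Anthropic's document block and, when any document content is present, adds the PDF beta header. A sketch of the resulting request fragment (values are illustrative):

```python
# Shape of the sub-message built for a PDF content part
sub_message_dict = {
    "type": "document",
    "source": {
        "type": "base64",                 # message_content.encode_format
        "media_type": "application/pdf",  # only application/pdf is accepted by this branch
        "data": "<base64-encoded PDF>",
    },
}

# Header sent alongside the request when document content is detected
extra_headers = {"anthropic-beta": "pdfs-2024-09-25"}
```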

api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml (+1 / -0)

    - multi-tool-call
    - agent-thought
    - stream-tool-call
+   - audio
  model_properties:
    mode: chat
    context_size: 128000

api/core/prompt/utils/prompt_message_util.py (+2 / -1)

+ from collections.abc import Sequence
  from typing import cast

  from core.model_runtime.entities import (

  class PromptMessageUtil:
      @staticmethod
-     def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
+     def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
          """
          Prompt messages to prompt for saving.
          :param model_mode: model mode

api/core/variables/segments.py (+10 / -2)

      @property
      def log(self) -> str:
-         return str(self.value)
+         return ""

      @property
      def text(self) -> str:
-         return str(self.value)
+         return ""


  class ArrayAnySegment(ArraySegment):

          for item in self.value:
              items.append(item.markdown)
          return "\n".join(items)

+     @property
+     def log(self) -> str:
+         return ""
+
+     @property
+     def text(self) -> str:
+         return ""

api/core/workflow/nodes/llm/entities.py (+16 / -2)

  class PromptConfig(BaseModel):
-     jinja2_variables: Optional[list[VariableSelector]] = None
+     jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
+
+     @field_validator("jinja2_variables", mode="before")
+     @classmethod
+     def convert_none_jinja2_variables(cls, v: Any):
+         if v is None:
+             return []
+         return v


  class LLMNodeChatModelMessage(ChatModelMessage):

  class LLMNodeData(BaseNodeData):
      model: ModelConfig
      prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-     prompt_config: Optional[PromptConfig] = None
+     prompt_config: PromptConfig = Field(default_factory=PromptConfig)
      memory: Optional[MemoryConfig] = None
      context: ContextConfig
      vision: VisionConfig = Field(default_factory=VisionConfig)
+
+     @field_validator("prompt_config", mode="before")
+     @classmethod
+     def convert_none_prompt_config(cls, v: Any):
+         if v is None:
+             return PromptConfig()
+         return v
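The `mode="before"` validators keep older node data with explicit nulls loadable. A hedged sketch of the effect (field names follow the entities in this file; the exact required fields of `BaseNodeData` may differ slightly):

```python
node_data = LLMNodeData.model_validate({
    "title": "LLM",
    "model": {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
    "prompt_template": [],
    "prompt_config": None,          # coerced to PromptConfig() by convert_none_prompt_config
    "context": {"enabled": False},
})
assert node_data.prompt_config.jinja2_variables == []   # None was normalized to an empty list
```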

api/core/workflow/nodes/llm/exc.py (+8 / -0)

  class NoPromptFoundError(LLMNodeError):
      """Raised when no prompt is found in the LLM configuration."""


+ class NotSupportedPromptTypeError(LLMNodeError):
+     """Raised when the prompt type is not supported."""
+
+
+ class MemoryRolePrefixRequiredError(LLMNodeError):
+     """Raised when memory role prefix is required for completion model."""

api/core/workflow/nodes/llm/node.py (+357 / -55)

import json import json
import logging
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast


from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AudioPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
TextPromptMessageContent, TextPromptMessageContent,
VideoPromptMessageContent,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import ( from core.variables import (
ObjectSegment, ObjectSegment,
StringSegment, StringSegment,
) )
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
InvalidVariableTypeError, InvalidVariableTypeError,
LLMModeRequiredError, LLMModeRequiredError,
LLMNodeError, LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError, ModelNotExistError,
NoPromptFoundError, NoPromptFoundError,
NotSupportedPromptTypeError,
VariableNotFoundError, VariableNotFoundError,
) )


if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File


logger = logging.getLogger(__name__)



class LLMNode(BaseNode[LLMNodeData]): class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData _node_data_cls = LLMNodeData


# fetch prompt messages # fetch prompt messages
if self.node_data.memory: if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise VariableNotFoundError("Query not found")
query = query.text
query = self.node_data.memory.query_prompt_template
else: else:
query = None query = None


prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
system_query=query,
inputs=inputs,
files=files,
user_query=query,
user_files=files,
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
memory_config=self.node_data.memory, memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled, vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail, vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
) )


process_data = { process_data = {
) )
) )
return return
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data,
)
)
return


outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}


self, self,
node_data_model: ModelConfig, node_data_model: ModelConfig,
model_instance: ModelInstance, model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]: ) -> Generator[NodeEvent, None, None]:
db.session.close() db.session.close()


def _fetch_prompt_messages( def _fetch_prompt_messages(
self, self,
*, *,
system_query: str | None = None,
inputs: dict[str, str] | None = None,
files: Sequence["File"],
user_query: str | None = None,
user_files: Sequence["File"],
context: str | None = None, context: str | None = None,
memory: TokenBufferMemory | None = None, memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
memory_config: MemoryConfig | None = None, memory_config: MemoryConfig | None = None,
vision_enabled: bool = False, vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL, vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = inputs or {}

prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=system_query or "",
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
)
stop = model_config.stop
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages = []

if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)

# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)

# Add current query to the prompt messages
if user_query:
message = LLMNodeChatModelMessage(
text=user_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)

elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)

# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content

# Add current query to the prompt message
if user_query:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_messages[0].content = prompt_content
else:
errmsg = f"Prompt type {type(prompt_template)} is not supported"
logger.warning(errmsg)
raise NotSupportedPromptTypeError(errmsg)

if vision_enabled and user_files:
file_prompts = []
for file in user_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))

# Filter prompt messages
filtered_prompt_messages = [] filtered_prompt_messages = []
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if prompt_message.is_empty():
continue

if not isinstance(prompt_message.content, str):
if isinstance(prompt_message.content, list):
prompt_message_content = [] prompt_message_content = []
for content_item in prompt_message.content or []:
# Skip image if vision is disabled
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue continue


if isinstance(content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config,
# cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail
prompt_message_content.append(content_item)
elif isinstance(
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
# Skip content if corresponding feature is not supported
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_config.model_schema.features
)
): ):
prompt_message_content.append(content_item)

if len(prompt_message_content) > 1:
prompt_message.content = prompt_message_content
elif (
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
):
continue
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data prompt_message.content = prompt_message_content[0].data

else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message) filtered_prompt_messages.append(prompt_message)


if not filtered_prompt_messages:
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError( raise NoPromptFoundError(
"No prompt found in the LLM configuration. " "No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding." "Please ensure a prompt is properly configured before proceeding."
) )


stop = model_config.stop
return filtered_prompt_messages, stop return filtered_prompt_messages, stop


@classmethod @classmethod
} }
}, },
} }


def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
raise NotImplementedError(f"Role {role} is not supported")


def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""

jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
)
result_text = code_execute_resp["result"]
return result_text


def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)

# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)

# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)

if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_messages.append(prompt_message)

return prompt_messages


def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000

model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)

curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)

rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)

return rest_tokens


def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages


def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text


def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.

Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion

Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages
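One behavioural consequence of the filtering logic above: non-text content parts are silently dropped unless the model schema declares the matching `ModelFeature`. A hypothetical helper (not part of this commit) restating that gate:

```python
REQUIRED_FEATURE = {
    PromptMessageContentType.IMAGE: ModelFeature.VISION,
    PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
    PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
    PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
}

def is_content_supported(content_type, features) -> bool:
    """Text is always kept; other content types need the matching model feature."""
    required = REQUIRED_FEATURE.get(content_type)
    return required is None or required in features
```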

api/core/workflow/nodes/question_classifier/question_classifier_node.py (+4 / -2)

      )
      prompt_messages, stop = self._fetch_prompt_messages(
          prompt_template=prompt_template,
-         system_query=query,
+         user_query=query,
          memory=memory,
          model_config=model_config,
-         files=files,
+         user_files=files,
          vision_enabled=node_data.vision.enabled,
          vision_detail=node_data.vision.configs.detail,
+         variable_pool=variable_pool,
+         jinja2_variables=[],
      )

      # handle invoke result

api/poetry.lock (+16 / -1)

  [package.extras]
  test = ["pytest (>=6)"]

+ [[package]]
+ name = "faker"
+ version = "32.1.0"
+ description = "Faker is a Python package that generates fake data for you."
+ optional = false
+ python-versions = ">=3.8"
+ files = [
+     {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
+     {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
+ ]
+
+ [package.dependencies]
+ python-dateutil = ">=2.4"
+ typing-extensions = "*"
+
  [[package]]
  name = "fal-client"
  version = "0.5.6"

  [metadata]
  lock-version = "2.0"
  python-versions = ">=3.10,<3.13"
- content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69"
+ content-hash = "d149b24ce7a203fa93eddbe8430d8ea7e5160a89c8d348b1b747c19899065639"

api/pyproject.toml (+1 / -0)

  optional = true
  [tool.poetry.group.dev.dependencies]
  coverage = "~7.2.4"
+ faker = "~32.1.0"
  pytest = "~8.3.2"
  pytest-benchmark = "~4.0.0"
  pytest-env = "~1.1.3"

api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py (+0 / -1)

  )
  from core.model_runtime.errors.validate import CredentialsValidateFailedError
  from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
- from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock


  @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)

api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py (+3 / -11)

  from core.model_runtime.entities.rerank_entities import RerankResult
  from core.model_runtime.errors.validate import CredentialsValidateFailedError
- from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
+ from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel


  def test_validate_credentials():
-     model = AzureAIStudioRerankModel()
+     model = AzureRerankModel()

      with pytest.raises(CredentialsValidateFailedError):
          model.validate_credentials(
              model="azure-ai-studio-rerank-v1",
              credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
-             query="What is the capital of the United States?",
-             docs=[
-                 "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
-                 "Census, Carson City had a population of 55,274.",
-                 "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
-                 "are a political division controlled by the United States. Its capital is Saipan.",
-             ],
-             score_threshold=0.8,
          )


  def test_invoke_model():
-     model = AzureAIStudioRerankModel()
+     model = AzureRerankModel()

      result = model.invoke(
          model="azure-ai-studio-rerank-v1",

api/tests/unit_tests/core/workflow/nodes/llm/test_node.py (+455 / -96)

from collections.abc import Sequence
from typing import Optional

import pytest import pytest


from core.app.entities.app_invoke_entities import InvokeFrom
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType from models.workflow import WorkflowType
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario




class MockTokenBufferMemory:
    def __init__(self, history_messages=None):
        self.history_messages = history_messages or []

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
    ) -> Sequence[PromptMessage]:
        if message_limit is not None:
            return self.history_messages[-message_limit * 2 :]
        return self.history_messages


@pytest.fixture
def llm_node():
    data = LLMNodeData(
        title="Test LLM",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
        prompt_template=[],
        memory=None,
        context=ContextConfig(enabled=False),
        vision=VisionConfig(
            enabled=True,
            configs=VisionConfigOptions(
                variable_selector=["sys", "files"],
                detail=ImagePromptMessageContent.DETAIL.HIGH,
            ),
        ),
    )
    variable_pool = VariablePool(
        system_variables={},
        user_inputs={},
    )
    node = LLMNode(
        id="1",
        config={
            "id": "1",
            "data": data.model_dump(),
        },
        graph_init_params=GraphInitParams(
            tenant_id="1",
            app_id="1",
            workflow_type=WorkflowType.WORKFLOW,
            workflow_id="1",
            graph_config={},
            user_id="1",
            user_from=UserFrom.ACCOUNT,
            invoke_from=InvokeFrom.SERVICE_API,
            call_depth=0,
        ),
        graph=Graph(
            root_node_id="1",
            answer_stream_generate_routes=AnswerStreamGenerateRoute(
                answer_dependencies={},
                answer_generate_route={},
            ),
            end_stream_param=EndStreamParam(
                end_dependencies={},
                end_stream_variable_selector_mapping={},
            ),
        ),
        graph_runtime_state=GraphRuntimeState(
            variable_pool=variable_pool,
            start_at=0,
        ),
    )
    return node


@pytest.fixture
def model_config():
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory()
provider_instance = model_provider_factory.get_provider_instance("openai")
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)

# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance.get_provider_schema(),
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
custom_configuration=CustomConfiguration(provider=None),
model_settings=[],
),
provider_instance=provider_instance,
model_type_instance=model_type_instance,
)


    # Create and return a ModelConfigWithCredentialsEntity
    return ModelConfigWithCredentialsEntity(
        provider="openai",
        model="gpt-3.5-turbo",
        model_schema=AIModelEntity(
            model="gpt-3.5-turbo",
            label=I18nObject(en_US="GPT-3.5 Turbo"),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={},
        ),
        mode="chat",
        credentials={},
        parameters={},
        provider_model_bundle=provider_model_bundle,
    )


def test_fetch_files_with_file_segment(llm_node):
file = File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)

result = llm_node._fetch_files(selector=["sys", "files"])
assert result == [file]


def test_fetch_files_with_array_file_segment(llm_node):
files = [
File(
id="1", id="1",
tenant_id="test", tenant_id="test",
type=FileType.IMAGE, type=FileType.IMAGE,
filename="test.jpg",
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE, transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1", related_id="1",
),
File(
id="2",
tenant_id="test",
type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))

result = llm_node._fetch_files(selector=["sys", "files"])
assert result == files


def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())

result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []


def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))

result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []


def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
assert result == []


def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
prompt_template = []
llm_node.node_data.prompt_template = prompt_template

fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()
files = [
File(
id="1",
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
) )
    ]

fake_query = faker.sentence()

prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
context=None,
memory=None,
model_config=model_config,
prompt_template=prompt_template,
memory_config=None,
vision_enabled=False,
vision_detail=fake_vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
)

assert prompt_messages == [UserPromptMessage(content=fake_query)]


def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Setup dify config
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"

# Generate fake values for prompt template
fake_assistant_prompt = faker.sentence()
fake_query = faker.sentence()
fake_context = faker.sentence()
fake_window_size = faker.random_int(min=1, max=3)
fake_vision_detail = faker.random_element(
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
)
fake_remote_url = faker.url()

# Setup mock memory with history messages
mock_history = [
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
UserPromptMessage(content=faker.sentence()),
AssistantPromptMessage(content=faker.sentence()),
]


    # Setup memory configuration
    memory_config = MemoryConfig(
        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
        window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
        query_prompt_template=None,
    )

    memory = MockTokenBufferMemory(history_messages=mock_history)

# Test scenarios covering different file input combinations
test_scenarios = [
LLMNodeTestScenario(
description="No files",
user_query=fake_query,
user_files=[],
features=[],
vision_enabled=False,
vision_detail=None,
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(content=fake_query),
],
),
LLMNodeTestScenario(
description="User files",
user_query=fake_query,
user_files=[
File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text=fake_context,
role=PromptMessageRole.SYSTEM,
edition_type="basic",
),
LLMNodeChatModelMessage(
text="{#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
),
LLMNodeChatModelMessage(
text=fake_assistant_prompt,
role=PromptMessageRole.ASSISTANT,
edition_type="basic",
),
],
expected_messages=[
SystemPromptMessage(content=fake_context),
UserPromptMessage(content=fake_context),
AssistantPromptMessage(content=fake_assistant_prompt),
]
+ mock_history[fake_window_size * -2 :]
+ [
UserPromptMessage(
content=[
TextPromptMessageContent(data=fake_query),
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
],
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
vision_enabled=False,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
]
+ mock_history[fake_window_size * -2 :]
+ [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File without vision feature",
user_query=fake_query,
user_files=[],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File with video file and vision feature",
user_query=fake_query,
user_files=[],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.VIDEO,
filename="test1.mp4",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
extension="mp4",
)
},
),
]


    for scenario in test_scenarios:
        model_config.model_schema.features = scenario.features

        for k, v in scenario.file_variables.items():
            selector = k.split(".")
            llm_node.graph_runtime_state.variable_pool.add(selector, v)

# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
user_files=scenario.user_files,
context=fake_context,
memory=memory,
model_config=model_config,
prompt_template=scenario.prompt_template,
memory_config=memory_config,
vision_enabled=scenario.vision_enabled,
vision_detail=scenario.vision_detail,
variable_pool=llm_node.graph_runtime_state.variable_pool,
jinja2_variables=[],
)


        # Verify the result
        assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
        assert (
            prompt_messages == scenario.expected_messages
        ), f"Message content mismatch in scenario: {scenario.description}"

+ 25
- 0
api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py View file

from collections.abc import Mapping, Sequence

from pydantic import BaseModel, Field

from core.file import File
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelFeature
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage


class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""

description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
window_size: int = Field(..., description="Window size for memory")
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
file_variables: Mapping[str, File | Sequence[File]] = Field(
default_factory=dict, description="List of file variables"
)
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
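For illustration, a minimal scenario instance could be built like the sketch below; the query text, window size, and prompt wording are made-up placeholders rather than values taken from the tests above, and only the required fields are set (the rest fall back to their defaults).

from core.model_runtime.entities.message_entities import PromptMessageRole, SystemPromptMessage
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario

# Hypothetical example: a scenario with a single system prompt and no files.
scenario = LLMNodeTestScenario(
    description="System prompt only, no files",  # placeholder description
    user_query="Hello",  # placeholder query
    window_size=2,  # placeholder memory window
    prompt_template=[
        LLMNodeChatModelMessage(
            text="You are a helpful assistant.",  # placeholder prompt text
            role=PromptMessageRole.SYSTEM,
            edition_type="basic",
        ),
    ],
    expected_messages=[SystemPromptMessage(content="You are a helpful assistant.")],
)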

+ 1
- 0
web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx View file

hideSearch
vars={availableVars}
onChange={handleSelectVar}
isSupportFileVar={false}
/>
</div>
)}

+ 3
- 0
web/app/components/workflow/nodes/_base/components/variable/var-list.tsx View file

isSupportConstantValue?: boolean
onlyLeafNodeVar?: boolean
filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean
isSupportFileVar?: boolean
}


const VarList: FC<Props> = ({
isSupportConstantValue,
onlyLeafNodeVar,
filterVar,
isSupportFileVar = true,
}) => {
const { t } = useTranslation()


defaultVarKindType={item.variable_type}
onlyLeafNodeVar={onlyLeafNodeVar}
filterVar={filterVar}
isSupportFileVar={isSupportFileVar}
/>
{!readonly && (
<RemoveButton

+ 3
- 0
web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx View file

isInTable?: boolean
onRemove?: () => void
typePlaceHolder?: string
isSupportFileVar?: boolean
}


const VarReferencePicker: FC<Props> = ({
isInTable,
onRemove,
typePlaceHolder,
isSupportFileVar = true,
}) => {
const { t } = useTranslation()
const store = useStoreApi()
vars={outputVars}
onChange={handleVarReferenceChange}
itemWidth={isAddBtnTrigger ? 260 : triggerWidth}
isSupportFileVar={isSupportFileVar}
/>
)}
</PortalToFollowElemContent>

+ 3
- 1
web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx View file

vars: NodeOutPutVar[]
onChange: (value: ValueSelector, varDetail: Var) => void
itemWidth?: number
isSupportFileVar?: boolean
}
const VarReferencePopup: FC<Props> = ({
vars,
onChange,
itemWidth,
isSupportFileVar = true,
}) => {
// max-h-[300px] overflow-y-auto todo: use portal to handle long list
return (
vars={vars}
onChange={onChange}
itemWidth={itemWidth}
isSupportFileVar
isSupportFileVar={isSupportFileVar}
/>
</div >
)

+ 1
- 0
web/app/components/workflow/nodes/code/panel.tsx View file

list={inputs.variables}
onChange={handleVarListChange}
filterVar={filterVar}
isSupportFileVar={false}
/>
</Field>
<Split />

+ 1
- 0
web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx View file

onEditionTypeChange={onEditionTypeChange}
varList={varList}
handleAddVariable={handleAddVariable}
isSupportFileVar
/>
)
}

+ 4
- 1
web/app/components/workflow/nodes/llm/panel.tsx View file

handleStop,
varInputs,
runResult,
filterJinjia2InputVar,
} = useConfig(id, data)


const model = inputs.model
list={inputs.prompt_config?.jinja2_variables || []}
onChange={handleVarListChange}
onVarNameChange={handleVarNameChange}
filterVar={filterVar}
filterVar={filterJinjia2InputVar}
isSupportFileVar={false}
/>
</Field>
)}
hasSetBlockStatus={hasSetBlockStatus}
nodesOutputVars={availableVars}
availableNodes={availableNodesWithParent}
isSupportFileVar
/>


{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (

+ 6
- 1
web/app/components/workflow/nodes/llm/use-config.ts View file

}, [inputs, setInputs])


const filterInputVar = useCallback((varPayload: Var) => {
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type)
}, [])

const filterJinjia2InputVar = useCallback((varPayload: Var) => {
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
}, [])


const filterMemoryPromptVar = useCallback((varPayload: Var) => {
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type)
}, [])


const {
handleRun,
handleStop,
runResult,
filterJinjia2InputVar,
}
}



+ 1
- 0
web/app/components/workflow/nodes/template-transform/panel.tsx View file

onChange={handleVarListChange}
onVarNameChange={handleVarNameChange}
filterVar={filterVar}
isSupportFileVar={false}
/>
</Field>
<Split />
