Co-authored-by: 朱庆超 <zhuqingchao@xiaomi.com>
Co-authored-by: crazywoola <427733928@qq.com>
@@ -21,14 +21,13 @@ from core.model_runtime.entities import (
     AssistantPromptMessage,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
@@ -501,7 +500,7 @@ class BaseAgentRunner(AppRunner):
         )
         if not file_objs:
             return UserPromptMessage(content=message.query)
-        prompt_message_contents: list[PromptMessageContent] = []
+        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
         prompt_message_contents.append(TextPromptMessageContent(data=message.query))
         for file in file_objs:
             prompt_message_contents.append(
@@ -5,12 +5,11 @@ from core.file import file_manager
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -40,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
             # get image detail config
@@ -15,14 +15,13 @@ from core.model_runtime.entities import (
     LLMResultChunkDelta,
     LLMUsage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageContentType,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
@@ -395,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))
             # get image detail config
@@ -7,9 +7,9 @@ from core.model_runtime.entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
-    MultiModalPromptMessageContent,
     VideoPromptMessageContent,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from extensions.ext_storage import storage

 from . import helpers
@@ -43,7 +43,7 @@ def to_prompt_message_content(
     /,
     *,
     image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-) -> MultiModalPromptMessageContent:
+) -> PromptMessageContentUnionTypes:
     if f.extension is None:
         raise ValueError("Missing file extension")
     if f.mime_type is None:
@@ -58,7 +58,7 @@ def to_prompt_message_content(
     if f.type == FileType.IMAGE:
         params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

-    prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
+    prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
         FileType.IMAGE: ImagePromptMessageContent,
         FileType.AUDIO: AudioPromptMessageContent,
         FileType.VIDEO: VideoPromptMessageContent,
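
For context, a minimal standalone sketch of the class-map dispatch that to_prompt_message_content performs. The real File entity and content classes are not reproduced here; every name below is an illustrative stand-in.

    from collections.abc import Mapping
    from enum import StrEnum


    class FileType(StrEnum):
        IMAGE = "image"
        AUDIO = "audio"


    class Content:
        """Stand-in base for the multi-modal content classes."""

        def __init__(self, **params: object) -> None:
            self.params = params


    class ImageContent(Content):
        pass


    class AudioContent(Content):
        pass


    # Each supported file type maps to the concrete content class to build.
    prompt_class_map: Mapping[FileType, type[Content]] = {
        FileType.IMAGE: ImageContent,
        FileType.AUDIO: AudioContent,
    }


    def to_content(file_type: FileType, **params: object) -> Content:
        try:
            return prompt_class_map[file_type](**params)
        except KeyError:
            # Unsupported types raise a clear error instead of returning None.
            raise ValueError(f"file type {file_type} is not supported") from None

Keeping the construction logic in one mapping means adding a new modality is one entry in the map rather than another if/elif branch.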
@@ -8,11 +8,11 @@ from core.model_runtime.entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageRole,
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from factories import file_factory
@@ -100,7 +100,7 @@ class TokenBufferMemory:
             if not file_objs:
                 prompt_messages.append(UserPromptMessage(content=message.query))
             else:
-                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents: list[PromptMessageContentUnionTypes] = []
                 prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file in file_objs:
                     prompt_message = file_manager.to_prompt_message_content(
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Any, Optional, Union
+from typing import Annotated, Any, Literal, Optional, Union

 from pydantic import BaseModel, Field, field_serializer, field_validator
@@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
 class PromptMessageContent(BaseModel):
     """
     Model class for prompt message content.
     """

-    type: PromptMessageContentType
+    pass


 class TextPromptMessageContent(PromptMessageContent):
@@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
     Model class for text prompt message content.
     """

-    type: PromptMessageContentType = PromptMessageContentType.TEXT
+    type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
     data: str
@@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
     Model class for multi-modal prompt message content.
     """

-    type: PromptMessageContentType
     format: str = Field(default=..., description="the format of multi-modal file")
     base64_data: str = Field(default="", description="the base64 data of multi-modal file")
     url: str = Field(default="", description="the url of multi-modal file")
@@ -94,11 +89,11 @@
 class VideoPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.VIDEO
+    type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO


 class AudioPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.AUDIO
+    type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO


 class ImagePromptMessageContent(MultiModalPromptMessageContent):
@@ -110,12 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
         LOW = "low"
         HIGH = "high"

-    type: PromptMessageContentType = PromptMessageContentType.IMAGE
+    type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW


 class DocumentPromptMessageContent(MultiModalPromptMessageContent):
-    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
+
+
+PromptMessageContentUnionTypes = Annotated[
+    Union[
+        TextPromptMessageContent,
+        ImagePromptMessageContent,
+        DocumentPromptMessageContent,
+        AudioPromptMessageContent,
+        VideoPromptMessageContent,
+    ],
+    Field(discriminator="type"),
+]


 class PromptMessage(BaseModel):
@@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
     """

     role: PromptMessageRole
-    content: Optional[str | Sequence[PromptMessageContent]] = None
+    content: Optional[str | list[PromptMessageContentUnionTypes]] = None
     name: Optional[str] = None

     def is_empty(self) -> bool:
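
For context, a minimal self-contained sketch of how a Pydantic v2 discriminated union like PromptMessageContentUnionTypes resolves plain dicts back into the concrete subclasses. The class names here are shortened stand-ins, not the real entities:

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter


    class TextContent(BaseModel):
        type: Literal["text"] = "text"
        data: str


    class ImageContent(BaseModel):
        type: Literal["image"] = "image"
        url: str = ""


    # The Literal "type" field on each subclass acts as the discriminator tag.
    Content = Annotated[Union[TextContent, ImageContent], Field(discriminator="type")]

    adapter = TypeAdapter(list[Content])
    items = adapter.validate_python(
        [
            {"type": "text", "data": "hello"},
            {"type": "image", "url": "https://example.com/a.jpg"},
        ]
    )
    assert isinstance(items[0], TextContent)  # the tag selects the subclass
    assert isinstance(items[1], ImageContent)

When validating serialized data, a bare base-class annotation would coerce every element to the base model and lose the subclass fields; the tagged union is what lets PromptMessage.content round-trip.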
@@ -9,13 +9,12 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
-    PromptMessageContent,
     PromptMessageRole,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@@ -125,7 +124,7 @@ class AdvancedPromptTransform(PromptTransform):
             prompt = Jinja2Formatter.format(prompt, prompt_inputs)

         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(
@@ -201,7 +200,7 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)

         if files and query is not None:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file in files:
                 prompt_message_contents.append(
@@ -11,7 +11,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform):
         image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
     ) -> UserPromptMessage:
         if files:
-            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             for file in files:
                 prompt_message_contents.append(
@@ -24,7 +24,7 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    PromptMessageContent,
+    PromptMessageContentUnionTypes,
     PromptMessageRole,
     SystemPromptMessage,
     UserPromptMessage,
@@ -594,8 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
     ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
-        # FIXME: fix the type error cause prompt_messages is type quick a few times
-        prompt_messages: list[Any] = []
+        prompt_messages: list[PromptMessage] = []

         if isinstance(prompt_template, list):
             # For chat model
@@ -657,12 +656,14 @@ class LLMNode(BaseNode[LLMNodeData]):
         # For issue #11247 - Check if prompt content is a string or a list
         prompt_content_type = type(prompt_content)
         if prompt_content_type == str:
+            prompt_content = str(prompt_content)
             if "#histories#" in prompt_content:
                 prompt_content = prompt_content.replace("#histories#", memory_text)
             else:
                 prompt_content = memory_text + "\n" + prompt_content
             prompt_messages[0].content = prompt_content
         elif prompt_content_type == list:
+            prompt_content = prompt_content if isinstance(prompt_content, list) else []
             for content_item in prompt_content:
                 if content_item.type == PromptMessageContentType.TEXT:
                     if "#histories#" in content_item.data:
@@ -675,9 +676,10 @@ class LLMNode(BaseNode[LLMNodeData]):
         # Add current query to the prompt message
         if sys_query:
             if prompt_content_type == str:
-                prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
+                prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                 prompt_messages[0].content = prompt_content
             elif prompt_content_type == list:
+                prompt_content = prompt_content if isinstance(prompt_content, list) else []
                 for content_item in prompt_content:
                     if content_item.type == PromptMessageContentType.TEXT:
                         content_item.data = sys_query + "\n" + content_item.data
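
For intuition, the placeholder handling in these two hunks reduces to the following standalone sketch; the template and values are made up for illustration:

    memory_text = "assistant: hi there"
    template = "#histories#\nUser question: #sys.query#"

    # Inject chat history: replace the marker if present, else prepend.
    if "#histories#" in template:
        content = template.replace("#histories#", memory_text)
    else:
        content = memory_text + "\n" + template

    # Then splice the current query into its marker.
    content = content.replace("#sys.query#", "What is Dify?")
    print(content)
    # assistant: hi there
    # User question: What is Dify?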
@@ -707,7 +709,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         filtered_prompt_messages = []
         for prompt_message in prompt_messages:
             if isinstance(prompt_message.content, list):
-                prompt_message_content = []
+                prompt_message_content: list[PromptMessageContentUnionTypes] = []
                 for content_item in prompt_message.content:
                     # Skip content if features are not defined
                     if not model_config.model_schema.features:
@@ -1132,7 +1134,9 @@ class LLMNode(BaseNode[LLMNodeData]):
     )


-def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
+def _combine_message_content_with_role(
+    *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
+):
     match role:
         case PromptMessageRole.USER:
             return UserPromptMessage(content=contents)
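
A hedged usage sketch for the widened helper signature, assuming TextPromptMessageContent is importable in this module; the SYSTEM and ASSISTANT branches are implied by the match statement but not shown in this hunk:

    # contents may now be a plain string or a discriminated-union list,
    # mirroring PromptMessage.content exactly.
    msg = _combine_message_content_with_role(
        contents=[TextPromptMessageContent(data="hello")],
        role=PromptMessageRole.USER,
    )
    assert isinstance(msg, UserPromptMessage)

    msg = _combine_message_content_with_role(contents="hello", role=PromptMessageRole.USER)
    assert msg.content == "hello"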
@@ -0,0 +1,27 @@
+from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+
+
+def test_build_prompt_message_with_prompt_message_contents():
+    prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
+    assert isinstance(prompt.content, list)
+    assert isinstance(prompt.content[0], TextPromptMessageContent)
+    assert prompt.content[0].data == "Hello, World!"
+
+
+def test_dump_prompt_message():
+    example_url = "https://example.com/image.jpg"
+    prompt = UserPromptMessage(
+        content=[
+            ImagePromptMessageContent(
+                url=example_url,
+                format="jpeg",
+                mime_type="image/jpeg",
+            )
+        ]
+    )
+    data = prompt.model_dump()
+    assert data["content"][0].get("url") == example_url
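
A natural companion check, sketched here rather than taken from the change set: re-validating the dumped payload should rebuild the concrete subclass via the Literal type discriminator (assumes Pydantic v2 model_validate semantics):

    def test_validate_dumped_prompt_message_round_trip():
        # Hypothetical test, not part of this diff.
        prompt = UserPromptMessage(
            content=[
                ImagePromptMessageContent(
                    url="https://example.com/image.jpg",
                    format="jpeg",
                    mime_type="image/jpeg",
                )
            ]
        )
        restored = UserPromptMessage.model_validate(prompt.model_dump())
        assert isinstance(restored.content, list)
        assert isinstance(restored.content[0], ImagePromptMessageContent)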