Co-authored-by: 朱庆超 <zhuqingchao@xiaomi.com>
Co-authored-by: crazywoola <427733928@qq.com>
tags/1.3.0
| AssistantPromptMessage, | AssistantPromptMessage, | ||||
| LLMUsage, | LLMUsage, | ||||
| PromptMessage, | PromptMessage, | ||||
| PromptMessageContent, | |||||
| PromptMessageTool, | PromptMessageTool, | ||||
| SystemPromptMessage, | SystemPromptMessage, | ||||
| TextPromptMessageContent, | TextPromptMessageContent, | ||||
| ToolPromptMessage, | ToolPromptMessage, | ||||
| UserPromptMessage, | UserPromptMessage, | ||||
| ) | ) | ||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes | |||||
| from core.model_runtime.entities.model_entities import ModelFeature | from core.model_runtime.entities.model_entities import ModelFeature | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | from core.prompt.utils.extract_thread_messages import extract_thread_messages | ||||
| ) | ) | ||||
| if not file_objs: | if not file_objs: | ||||
| return UserPromptMessage(content=message.query) | return UserPromptMessage(content=message.query) | ||||
| prompt_message_contents: list[PromptMessageContent] = [] | |||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | ||||
| for file in file_objs: | for file in file_objs: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( |
| from core.model_runtime.entities import ( | from core.model_runtime.entities import ( | ||||
| AssistantPromptMessage, | AssistantPromptMessage, | ||||
| PromptMessage, | PromptMessage, | ||||
| PromptMessageContent, | |||||
| SystemPromptMessage, | SystemPromptMessage, | ||||
| TextPromptMessageContent, | TextPromptMessageContent, | ||||
| UserPromptMessage, | UserPromptMessage, | ||||
| ) | ) | ||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| Organize user query | Organize user query | ||||
| """ | """ | ||||
| if self.files: | if self.files: | ||||
| prompt_message_contents: list[PromptMessageContent] = [] | |||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | prompt_message_contents.append(TextPromptMessageContent(data=query)) | ||||
| # get image detail config | # get image detail config |
| LLMResultChunkDelta, | LLMResultChunkDelta, | ||||
| LLMUsage, | LLMUsage, | ||||
| PromptMessage, | PromptMessage, | ||||
| PromptMessageContent, | |||||
| PromptMessageContentType, | PromptMessageContentType, | ||||
| SystemPromptMessage, | SystemPromptMessage, | ||||
| TextPromptMessageContent, | TextPromptMessageContent, | ||||
| ToolPromptMessage, | ToolPromptMessage, | ||||
| UserPromptMessage, | UserPromptMessage, | ||||
| ) | ) | ||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes | |||||
| from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform | from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform | ||||
| from core.tools.entities.tool_entities import ToolInvokeMeta | from core.tools.entities.tool_entities import ToolInvokeMeta | ||||
| from core.tools.tool_engine import ToolEngine | from core.tools.tool_engine import ToolEngine | ||||
| Organize user query | Organize user query | ||||
| """ | """ | ||||
| if self.files: | if self.files: | ||||
| prompt_message_contents: list[PromptMessageContent] = [] | |||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | prompt_message_contents.append(TextPromptMessageContent(data=query)) | ||||
| # get image detail config | # get image detail config |
| AudioPromptMessageContent, | AudioPromptMessageContent, | ||||
| DocumentPromptMessageContent, | DocumentPromptMessageContent, | ||||
| ImagePromptMessageContent, | ImagePromptMessageContent, | ||||
| MultiModalPromptMessageContent, | |||||
| VideoPromptMessageContent, | VideoPromptMessageContent, | ||||
| ) | ) | ||||
| from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes | |||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from . import helpers | from . import helpers | ||||
| /, | /, | ||||
| *, | *, | ||||
| image_detail_config: ImagePromptMessageContent.DETAIL | None = None, | image_detail_config: ImagePromptMessageContent.DETAIL | None = None, | ||||
| ) -> MultiModalPromptMessageContent: | |||||
| ) -> PromptMessageContentUnionTypes: | |||||
| if f.extension is None: | if f.extension is None: | ||||
| raise ValueError("Missing file extension") | raise ValueError("Missing file extension") | ||||
| if f.mime_type is None: | if f.mime_type is None: | ||||
| if f.type == FileType.IMAGE: | if f.type == FileType.IMAGE: | ||||
| params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW | params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW | ||||
| prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { | |||||
| prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { | |||||
| FileType.IMAGE: ImagePromptMessageContent, | FileType.IMAGE: ImagePromptMessageContent, | ||||
| FileType.AUDIO: AudioPromptMessageContent, | FileType.AUDIO: AudioPromptMessageContent, | ||||
| FileType.VIDEO: VideoPromptMessageContent, | FileType.VIDEO: VideoPromptMessageContent, |
| AssistantPromptMessage, | AssistantPromptMessage, | ||||
| ImagePromptMessageContent, | ImagePromptMessageContent, | ||||
| PromptMessage, | PromptMessage, | ||||
| PromptMessageContent, | |||||
| PromptMessageRole, | PromptMessageRole, | ||||
| TextPromptMessageContent, | TextPromptMessageContent, | ||||
| UserPromptMessage, | UserPromptMessage, | ||||
| ) | ) | ||||
| from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes | |||||
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | from core.prompt.utils.extract_thread_messages import extract_thread_messages | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| if not file_objs: | if not file_objs: | ||||
| prompt_messages.append(UserPromptMessage(content=message.query)) | prompt_messages.append(UserPromptMessage(content=message.query)) | ||||
| else: | else: | ||||
| prompt_message_contents: list[PromptMessageContent] = [] | |||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | ||||
| for file in file_objs: | for file in file_objs: | ||||
| prompt_message = file_manager.to_prompt_message_content( | prompt_message = file_manager.to_prompt_message_content( |
| from collections.abc import Sequence | from collections.abc import Sequence | ||||
| from enum import Enum, StrEnum | from enum import Enum, StrEnum | ||||
| from typing import Any, Optional, Union | |||||
| from typing import Annotated, Any, Literal, Optional, Union | |||||
| from pydantic import BaseModel, Field, field_serializer, field_validator | from pydantic import BaseModel, Field, field_serializer, field_validator | ||||
| class PromptMessageContent(BaseModel): | class PromptMessageContent(BaseModel): | ||||
| """ | |||||
| Model class for prompt message content. | |||||
| """ | |||||
| type: PromptMessageContentType | |||||
| pass | |||||
| class TextPromptMessageContent(PromptMessageContent): | class TextPromptMessageContent(PromptMessageContent): | ||||
| Model class for text prompt message content. | Model class for text prompt message content. | ||||
| """ | """ | ||||
| type: PromptMessageContentType = PromptMessageContentType.TEXT | |||||
| type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT | |||||
| data: str | data: str | ||||
| Model class for multi-modal prompt message content. | Model class for multi-modal prompt message content. | ||||
| """ | """ | ||||
| type: PromptMessageContentType | |||||
| format: str = Field(default=..., description="the format of multi-modal file") | format: str = Field(default=..., description="the format of multi-modal file") | ||||
| base64_data: str = Field(default="", description="the base64 data of multi-modal file") | base64_data: str = Field(default="", description="the base64 data of multi-modal file") | ||||
| url: str = Field(default="", description="the url of multi-modal file") | url: str = Field(default="", description="the url of multi-modal file") | ||||
| class VideoPromptMessageContent(MultiModalPromptMessageContent): | class VideoPromptMessageContent(MultiModalPromptMessageContent): | ||||
| type: PromptMessageContentType = PromptMessageContentType.VIDEO | |||||
| type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO | |||||
| class AudioPromptMessageContent(MultiModalPromptMessageContent): | class AudioPromptMessageContent(MultiModalPromptMessageContent): | ||||
| type: PromptMessageContentType = PromptMessageContentType.AUDIO | |||||
| type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO | |||||
| class ImagePromptMessageContent(MultiModalPromptMessageContent): | class ImagePromptMessageContent(MultiModalPromptMessageContent): | ||||
| LOW = "low" | LOW = "low" | ||||
| HIGH = "high" | HIGH = "high" | ||||
| type: PromptMessageContentType = PromptMessageContentType.IMAGE | |||||
| type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE | |||||
| detail: DETAIL = DETAIL.LOW | detail: DETAIL = DETAIL.LOW | ||||
| class DocumentPromptMessageContent(MultiModalPromptMessageContent): | class DocumentPromptMessageContent(MultiModalPromptMessageContent): | ||||
| type: PromptMessageContentType = PromptMessageContentType.DOCUMENT | |||||
| type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT | |||||
| PromptMessageContentUnionTypes = Annotated[ | |||||
| Union[ | |||||
| TextPromptMessageContent, | |||||
| ImagePromptMessageContent, | |||||
| DocumentPromptMessageContent, | |||||
| AudioPromptMessageContent, | |||||
| VideoPromptMessageContent, | |||||
| ], | |||||
| Field(discriminator="type"), | |||||
| ] | |||||
| class PromptMessage(BaseModel): | class PromptMessage(BaseModel): | ||||
| """ | """ | ||||
| role: PromptMessageRole | role: PromptMessageRole | ||||
| content: Optional[str | Sequence[PromptMessageContent]] = None | |||||
| content: Optional[str | list[PromptMessageContentUnionTypes]] = None | |||||
| name: Optional[str] = None | name: Optional[str] = None | ||||
| def is_empty(self) -> bool: | def is_empty(self) -> bool: |
| from core.model_runtime.entities import ( | from core.model_runtime.entities import ( | ||||
| AssistantPromptMessage, | AssistantPromptMessage, | ||||
| PromptMessage, | PromptMessage, | ||||
| PromptMessageContent, | |||||
| PromptMessageRole, | PromptMessageRole, | ||||
| SystemPromptMessage, | SystemPromptMessage, | ||||
| TextPromptMessageContent, | TextPromptMessageContent, | ||||
| UserPromptMessage, | UserPromptMessage, | ||||
| ) | ) | ||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes | |||||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig | from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig | ||||
| from core.prompt.prompt_transform import PromptTransform | from core.prompt.prompt_transform import PromptTransform | ||||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | from core.prompt.utils.prompt_template_parser import PromptTemplateParser | ||||
| prompt = Jinja2Formatter.format(prompt, prompt_inputs) | prompt = Jinja2Formatter.format(prompt, prompt_inputs) | ||||
| if files: | if files: | ||||
| prompt_message_contents: list[PromptMessageContent] = [] | |||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | ||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) | prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) | ||||
| if files and query is not None: | if files and query is not None: | ||||
| prompt_message_contents: list[PromptMessageContent] = [] | |||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | prompt_message_contents.append(TextPromptMessageContent(data=query)) | ||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( |
| from core.model_runtime.entities.message_entities import ( | from core.model_runtime.entities.message_entities import ( | ||||
| ImagePromptMessageContent, | ImagePromptMessageContent, | ||||
| PromptMessage, | PromptMessage, | ||||
| PromptMessageContent, | |||||
| PromptMessageContentUnionTypes, | |||||
| SystemPromptMessage, | SystemPromptMessage, | ||||
| TextPromptMessageContent, | TextPromptMessageContent, | ||||
| UserPromptMessage, | UserPromptMessage, | ||||
| image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, | image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, | ||||
| ) -> UserPromptMessage: | ) -> UserPromptMessage: | ||||
| if files: | if files: | ||||
| prompt_message_contents: list[PromptMessageContent] = [] | |||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | ||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( |
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | ||||
| from core.model_runtime.entities.message_entities import ( | from core.model_runtime.entities.message_entities import ( | ||||
| AssistantPromptMessage, | AssistantPromptMessage, | ||||
| PromptMessageContent, | |||||
| PromptMessageContentUnionTypes, | |||||
| PromptMessageRole, | PromptMessageRole, | ||||
| SystemPromptMessage, | SystemPromptMessage, | ||||
| UserPromptMessage, | UserPromptMessage, | ||||
| variable_pool: VariablePool, | variable_pool: VariablePool, | ||||
| jinja2_variables: Sequence[VariableSelector], | jinja2_variables: Sequence[VariableSelector], | ||||
| ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: | ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: | ||||
| # FIXME: fix the type error cause prompt_messages is type quick a few times | |||||
| prompt_messages: list[Any] = [] | |||||
| prompt_messages: list[PromptMessage] = [] | |||||
| if isinstance(prompt_template, list): | if isinstance(prompt_template, list): | ||||
| # For chat model | # For chat model | ||||
| # For issue #11247 - Check if prompt content is a string or a list | # For issue #11247 - Check if prompt content is a string or a list | ||||
| prompt_content_type = type(prompt_content) | prompt_content_type = type(prompt_content) | ||||
| if prompt_content_type == str: | if prompt_content_type == str: | ||||
| prompt_content = str(prompt_content) | |||||
| if "#histories#" in prompt_content: | if "#histories#" in prompt_content: | ||||
| prompt_content = prompt_content.replace("#histories#", memory_text) | prompt_content = prompt_content.replace("#histories#", memory_text) | ||||
| else: | else: | ||||
| prompt_content = memory_text + "\n" + prompt_content | prompt_content = memory_text + "\n" + prompt_content | ||||
| prompt_messages[0].content = prompt_content | prompt_messages[0].content = prompt_content | ||||
| elif prompt_content_type == list: | elif prompt_content_type == list: | ||||
| prompt_content = prompt_content if isinstance(prompt_content, list) else [] | |||||
| for content_item in prompt_content: | for content_item in prompt_content: | ||||
| if content_item.type == PromptMessageContentType.TEXT: | if content_item.type == PromptMessageContentType.TEXT: | ||||
| if "#histories#" in content_item.data: | if "#histories#" in content_item.data: | ||||
| # Add current query to the prompt message | # Add current query to the prompt message | ||||
| if sys_query: | if sys_query: | ||||
| if prompt_content_type == str: | if prompt_content_type == str: | ||||
| prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query) | |||||
| prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) | |||||
| prompt_messages[0].content = prompt_content | prompt_messages[0].content = prompt_content | ||||
| elif prompt_content_type == list: | elif prompt_content_type == list: | ||||
| prompt_content = prompt_content if isinstance(prompt_content, list) else [] | |||||
| for content_item in prompt_content: | for content_item in prompt_content: | ||||
| if content_item.type == PromptMessageContentType.TEXT: | if content_item.type == PromptMessageContentType.TEXT: | ||||
| content_item.data = sys_query + "\n" + content_item.data | content_item.data = sys_query + "\n" + content_item.data | ||||
| filtered_prompt_messages = [] | filtered_prompt_messages = [] | ||||
| for prompt_message in prompt_messages: | for prompt_message in prompt_messages: | ||||
| if isinstance(prompt_message.content, list): | if isinstance(prompt_message.content, list): | ||||
| prompt_message_content = [] | |||||
| prompt_message_content: list[PromptMessageContentUnionTypes] = [] | |||||
| for content_item in prompt_message.content: | for content_item in prompt_message.content: | ||||
| # Skip content if features are not defined | # Skip content if features are not defined | ||||
| if not model_config.model_schema.features: | if not model_config.model_schema.features: | ||||
| ) | ) | ||||
| def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): | |||||
| def _combine_message_content_with_role( | |||||
| *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole | |||||
| ): | |||||
| match role: | match role: | ||||
| case PromptMessageRole.USER: | case PromptMessageRole.USER: | ||||
| return UserPromptMessage(content=contents) | return UserPromptMessage(content=contents) |
| from core.model_runtime.entities.message_entities import ( | |||||
| ImagePromptMessageContent, | |||||
| TextPromptMessageContent, | |||||
| UserPromptMessage, | |||||
| ) | |||||
| def test_build_prompt_message_with_prompt_message_contents(): | |||||
| prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) | |||||
| assert isinstance(prompt.content, list) | |||||
| assert isinstance(prompt.content[0], TextPromptMessageContent) | |||||
| assert prompt.content[0].data == "Hello, World!" | |||||
| def test_dump_prompt_message(): | |||||
| example_url = "https://example.com/image.jpg" | |||||
| prompt = UserPromptMessage( | |||||
| content=[ | |||||
| ImagePromptMessageContent( | |||||
| url=example_url, | |||||
| format="jpeg", | |||||
| mime_type="image/jpeg", | |||||
| ) | |||||
| ] | |||||
| ) | |||||
| data = prompt.model_dump() | |||||
| assert data["content"][0].get("url") == example_url |