Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/1.8.0
| if not file_objs: | if not file_objs: | ||||
| return UserPromptMessage(content=message.query) | return UserPromptMessage(content=message.query) | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | |||||
| for file in file_objs: | for file in file_objs: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content( | file_manager.to_prompt_message_content( | ||||
| image_detail_config=image_detail_config, | image_detail_config=image_detail_config, | ||||
| ) | ) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=message.query)) | |||||
| return UserPromptMessage(content=prompt_message_contents) | return UserPromptMessage(content=prompt_message_contents) |
| Organize user query | Organize user query | ||||
| """ | """ | ||||
| if self.files: | if self.files: | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||||
| # get image detail config | # get image detail config | ||||
| image_detail_config = ( | image_detail_config = ( | ||||
| self.application_generate_entity.file_upload_config.image_config.detail | self.application_generate_entity.file_upload_config.image_config.detail | ||||
| else None | else None | ||||
| ) | ) | ||||
| image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW | image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| for file in self.files: | for file in self.files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content( | file_manager.to_prompt_message_content( | ||||
| image_detail_config=image_detail_config, | image_detail_config=image_detail_config, | ||||
| ) | ) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | ||||
| else: | else: |
| Organize user query | Organize user query | ||||
| """ | """ | ||||
| if self.files: | if self.files: | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||||
| # get image detail config | # get image detail config | ||||
| image_detail_config = ( | image_detail_config = ( | ||||
| self.application_generate_entity.file_upload_config.image_config.detail | self.application_generate_entity.file_upload_config.image_config.detail | ||||
| else None | else None | ||||
| ) | ) | ||||
| image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW | image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| for file in self.files: | for file in self.files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content( | file_manager.to_prompt_message_content( | ||||
| image_detail_config=image_detail_config, | image_detail_config=image_detail_config, | ||||
| ) | ) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | ||||
| else: | else: |
| if files: | if files: | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | |||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | |||||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | ||||
| else: | else: | ||||
| query = parser.format(prompt_inputs) | query = parser.format(prompt_inputs) | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| if memory and memory_config: | if memory and memory_config: | ||||
| prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) | prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) | ||||
| if files and query is not None: | if files and query is not None: | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | |||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | ||||
| else: | else: | ||||
| prompt_messages.append(UserPromptMessage(content=query)) | prompt_messages.append(UserPromptMessage(content=query)) | ||||
| last_message = prompt_messages[-1] if prompt_messages else None | last_message = prompt_messages[-1] if prompt_messages else None | ||||
| if last_message and last_message.role == PromptMessageRole.USER: | if last_message and last_message.role == PromptMessageRole.USER: | ||||
| # get last user message content and add files | # get last user message content and add files | ||||
| prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] | |||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=cast(str, last_message.content))) | |||||
| last_message.content = prompt_message_contents | last_message.content = prompt_message_contents | ||||
| else: | else: | ||||
| prompt_message_contents = [TextPromptMessageContent(data="")] # not for query | |||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data="")) | |||||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | ||||
| else: | else: | ||||
| prompt_message_contents = [TextPromptMessageContent(data=query)] | |||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=query)) | |||||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | ||||
| elif query: | elif query: |
| ) -> UserPromptMessage: | ) -> UserPromptMessage: | ||||
| if files: | if files: | ||||
| prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | prompt_message_contents: list[PromptMessageContentUnionTypes] = [] | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | |||||
| for file in files: | for file in files: | ||||
| prompt_message_contents.append( | prompt_message_contents.append( | ||||
| file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) | ||||
| ) | ) | ||||
| prompt_message_contents.append(TextPromptMessageContent(data=prompt)) | |||||
| prompt_message = UserPromptMessage(content=prompt_message_contents) | prompt_message = UserPromptMessage(content=prompt_message_contents) | ||||
| else: | else: |
| ) | ) | ||||
| assert isinstance(prompt_messages[3].content, list) | assert isinstance(prompt_messages[3].content, list) | ||||
| assert len(prompt_messages[3].content) == 2 | assert len(prompt_messages[3].content) == 2 | ||||
| assert prompt_messages[3].content[1].data == files[0].remote_url | |||||
| assert prompt_messages[3].content[0].data == files[0].remote_url | |||||
| @pytest.fixture | @pytest.fixture |