|
|
|
@@ -1,19 +1,8 @@ |
|
|
|
import enum |
|
|
|
from typing import Any, cast |
|
|
|
from typing import Any |
|
|
|
|
|
|
|
from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage |
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
from core.model_runtime.entities.message_entities import ( |
|
|
|
AssistantPromptMessage, |
|
|
|
ImagePromptMessageContent, |
|
|
|
PromptMessage, |
|
|
|
SystemPromptMessage, |
|
|
|
TextPromptMessageContent, |
|
|
|
ToolPromptMessage, |
|
|
|
UserPromptMessage, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class PromptMessageFileType(enum.Enum): |
|
|
|
IMAGE = 'image' |
|
|
|
@@ -38,98 +27,3 @@ class ImagePromptMessageFile(PromptMessageFile): |
|
|
|
|
|
|
|
type: PromptMessageFileType = PromptMessageFileType.IMAGE |
|
|
|
detail: DETAIL = DETAIL.LOW |
|
|
|
|
|
|
|
|
|
|
|
class LCHumanMessageWithFiles(HumanMessage): |
|
|
|
# content: Union[str, list[Union[str, Dict]]] |
|
|
|
content: str |
|
|
|
files: list[PromptMessageFile] |
|
|
|
|
|
|
|
|
|
|
|
def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]: |
|
|
|
prompt_messages = [] |
|
|
|
for message in messages: |
|
|
|
if isinstance(message, HumanMessage): |
|
|
|
if isinstance(message, LCHumanMessageWithFiles): |
|
|
|
file_prompt_message_contents = [] |
|
|
|
for file in message.files: |
|
|
|
if file.type == PromptMessageFileType.IMAGE: |
|
|
|
file = cast(ImagePromptMessageFile, file) |
|
|
|
file_prompt_message_contents.append(ImagePromptMessageContent( |
|
|
|
data=file.data, |
|
|
|
detail=ImagePromptMessageContent.DETAIL.HIGH |
|
|
|
if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW |
|
|
|
)) |
|
|
|
|
|
|
|
prompt_message_contents = [TextPromptMessageContent(data=message.content)] |
|
|
|
prompt_message_contents.extend(file_prompt_message_contents) |
|
|
|
|
|
|
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) |
|
|
|
else: |
|
|
|
prompt_messages.append(UserPromptMessage(content=message.content)) |
|
|
|
elif isinstance(message, AIMessage): |
|
|
|
message_kwargs = { |
|
|
|
'content': message.content |
|
|
|
} |
|
|
|
|
|
|
|
if 'function_call' in message.additional_kwargs: |
|
|
|
message_kwargs['tool_calls'] = [ |
|
|
|
AssistantPromptMessage.ToolCall( |
|
|
|
id=message.additional_kwargs['function_call']['id'], |
|
|
|
type='function', |
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
name=message.additional_kwargs['function_call']['name'], |
|
|
|
arguments=message.additional_kwargs['function_call']['arguments'] |
|
|
|
) |
|
|
|
) |
|
|
|
] |
|
|
|
|
|
|
|
prompt_messages.append(AssistantPromptMessage(**message_kwargs)) |
|
|
|
elif isinstance(message, SystemMessage): |
|
|
|
prompt_messages.append(SystemPromptMessage(content=message.content)) |
|
|
|
elif isinstance(message, FunctionMessage): |
|
|
|
prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name)) |
|
|
|
|
|
|
|
return prompt_messages |
|
|
|
|
|
|
|
|
|
|
|
def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]: |
|
|
|
messages = [] |
|
|
|
for prompt_message in prompt_messages: |
|
|
|
if isinstance(prompt_message, UserPromptMessage): |
|
|
|
if isinstance(prompt_message.content, str): |
|
|
|
messages.append(HumanMessage(content=prompt_message.content)) |
|
|
|
else: |
|
|
|
message_contents = [] |
|
|
|
for content in prompt_message.content: |
|
|
|
if isinstance(content, TextPromptMessageContent): |
|
|
|
message_contents.append(content.data) |
|
|
|
elif isinstance(content, ImagePromptMessageContent): |
|
|
|
message_contents.append({ |
|
|
|
'type': 'image', |
|
|
|
'data': content.data, |
|
|
|
'detail': content.detail.value |
|
|
|
}) |
|
|
|
|
|
|
|
messages.append(HumanMessage(content=message_contents)) |
|
|
|
elif isinstance(prompt_message, AssistantPromptMessage): |
|
|
|
message_kwargs = { |
|
|
|
'content': prompt_message.content |
|
|
|
} |
|
|
|
|
|
|
|
if prompt_message.tool_calls: |
|
|
|
message_kwargs['additional_kwargs'] = { |
|
|
|
'function_call': { |
|
|
|
'id': prompt_message.tool_calls[0].id, |
|
|
|
'name': prompt_message.tool_calls[0].function.name, |
|
|
|
'arguments': prompt_message.tool_calls[0].function.arguments |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
messages.append(AIMessage(**message_kwargs)) |
|
|
|
elif isinstance(prompt_message, SystemPromptMessage): |
|
|
|
messages.append(SystemMessage(content=prompt_message.content)) |
|
|
|
elif isinstance(prompt_message, ToolPromptMessage): |
|
|
|
messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content)) |
|
|
|
|
|
|
|
return messages |