|
|
|
@@ -20,6 +20,7 @@ from core.model_runtime.entities import ( |
|
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage |
|
|
|
from core.model_runtime.entities.message_entities import ( |
|
|
|
AssistantPromptMessage, |
|
|
|
PromptMessageContent, |
|
|
|
PromptMessageRole, |
|
|
|
SystemPromptMessage, |
|
|
|
UserPromptMessage, |
|
|
|
@@ -828,14 +829,14 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): |
|
|
|
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): |
|
|
|
match role: |
|
|
|
case PromptMessageRole.USER: |
|
|
|
return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) |
|
|
|
return UserPromptMessage(content=contents) |
|
|
|
case PromptMessageRole.ASSISTANT: |
|
|
|
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) |
|
|
|
return AssistantPromptMessage(content=contents) |
|
|
|
case PromptMessageRole.SYSTEM: |
|
|
|
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) |
|
|
|
return SystemPromptMessage(content=contents) |
|
|
|
raise NotImplementedError(f"Role {role} is not supported") |
|
|
|
|
|
|
|
|
|
|
|
@@ -877,7 +878,9 @@ def _handle_list_messages( |
|
|
|
jinjia2_variables=jinja2_variables, |
|
|
|
variable_pool=variable_pool, |
|
|
|
) |
|
|
|
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) |
|
|
|
prompt_message = _combine_message_content_with_role( |
|
|
|
contents=[TextPromptMessageContent(data=result_text)], role=message.role |
|
|
|
) |
|
|
|
prompt_messages.append(prompt_message) |
|
|
|
else: |
|
|
|
# Get segment group from basic message |
|
|
|
@@ -908,12 +911,14 @@ def _handle_list_messages( |
|
|
|
# Create message with text from all segments |
|
|
|
plain_text = segment_group.text |
|
|
|
if plain_text: |
|
|
|
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) |
|
|
|
prompt_message = _combine_message_content_with_role( |
|
|
|
contents=[TextPromptMessageContent(data=plain_text)], role=message.role |
|
|
|
) |
|
|
|
prompt_messages.append(prompt_message) |
|
|
|
|
|
|
|
if file_contents: |
|
|
|
# Create message with image contents |
|
|
|
prompt_message = UserPromptMessage(content=file_contents) |
|
|
|
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) |
|
|
|
prompt_messages.append(prompt_message) |
|
|
|
|
|
|
|
return prompt_messages |
|
|
|
@@ -1018,6 +1023,8 @@ def _handle_completion_template( |
|
|
|
else: |
|
|
|
template_text = template.text |
|
|
|
result_text = variable_pool.convert_template(template_text).text |
|
|
|
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) |
|
|
|
prompt_message = _combine_message_content_with_role( |
|
|
|
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER |
|
|
|
) |
|
|
|
prompt_messages.append(prompt_message) |
|
|
|
return prompt_messages |