| @@ -2,7 +2,7 @@ import logging | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator, Sequence | |||
| from typing import Optional, Union | |||
| from typing import Optional, Union, cast | |||
| from pydantic import ConfigDict | |||
| @@ -20,6 +20,7 @@ from core.model_runtime.entities.model_entities import ( | |||
| PriceType, | |||
| ) | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str | |||
| from core.plugin.manager.model import PluginModelManager | |||
| logger = logging.getLogger(__name__) | |||
| @@ -280,7 +281,9 @@ class LargeLanguageModel(AIModel): | |||
| callbacks=callbacks, | |||
| ) | |||
| assistant_message.content += chunk.delta.message.content | |||
| text = convert_llm_result_chunk_to_str(chunk.delta.message.content) | |||
| current_content = cast(str, assistant_message.content) | |||
| assistant_message.content = current_content + text | |||
| real_model = chunk.model | |||
| if chunk.delta.usage: | |||
| usage = chunk.delta.usage | |||
| @@ -1,6 +1,8 @@ | |||
| import pydantic | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes | |||
| def dump_model(model: BaseModel) -> dict: | |||
| if hasattr(pydantic, "model_dump"): | |||
| @@ -8,3 +10,18 @@ def dump_model(model: BaseModel) -> dict: | |||
| return pydantic.model_dump(model) # type: ignore | |||
| else: | |||
| return model.model_dump() | |||
| def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str: | |||
| if content is None: | |||
| message_text = "" | |||
| elif isinstance(content, str): | |||
| message_text = content | |||
| elif isinstance(content, list): | |||
| # Assuming the list contains PromptMessageContent objects with a "data" attribute | |||
| message_text = "".join( | |||
| item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content | |||
| ) | |||
| else: | |||
| message_text = str(content) | |||
| return message_text | |||
| @@ -38,6 +38,7 @@ from core.model_runtime.entities.model_entities import ( | |||
| ) | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| @@ -269,18 +270,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: | |||
| if isinstance(invoke_result, LLMResult): | |||
| content = invoke_result.message.content | |||
| if content is None: | |||
| message_text = "" | |||
| elif isinstance(content, str): | |||
| message_text = content | |||
| elif isinstance(content, list): | |||
| # Assuming the list contains PromptMessageContent objects with a "data" attribute | |||
| message_text = "".join( | |||
| item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content | |||
| ) | |||
| else: | |||
| message_text = str(content) | |||
| message_text = convert_llm_result_chunk_to_str(invoke_result.message.content) | |||
| yield ModelInvokeCompletedEvent( | |||
| text=message_text, | |||
| @@ -295,7 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| usage = None | |||
| finish_reason = None | |||
| for result in invoke_result: | |||
| text = result.delta.message.content | |||
| text = convert_llm_result_chunk_to_str(result.delta.message.content) | |||
| full_text += text | |||
| yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) | |||