fix: LLMResultChunk cause concatenate str and list exception (#18852)

tags/1.3.1
非法操作 committed 6 months ago
parent commit c1559a7c8e

api/core/model_runtime/model_providers/__base/large_language_model.py  (+5 -2)

 import time
 import uuid
 from collections.abc import Generator, Sequence
-from typing import Optional, Union
+from typing import Optional, Union, cast

 from pydantic import ConfigDict
@@ ... @@
     PriceType,
 )
 from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.manager.model import PluginModelManager
@@ ... @@
 logger = logging.getLogger(__name__)
@@ ... @@
                 callbacks=callbacks,
             )
@@ ... @@
-                assistant_message.content += chunk.delta.message.content
+                text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
+                current_content = cast(str, assistant_message.content)
+                assistant_message.content = current_content + text
                 real_model = chunk.model
                 if chunk.delta.usage:
                     usage = chunk.delta.usage
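
The bug this hunk fixes: `assistant_message.content` starts out as a `str`, but with multimodal providers a streamed `chunk.delta.message.content` can arrive as a `list[PromptMessageContent]`, so the old `+=` raised a TypeError. A minimal standalone repro of the failure mode (the `TextContent` class below is a simplified, hypothetical stand-in for Dify's `PromptMessageContent`, not the real entity):

# Repro sketch: str += list raises the exception named in the commit title.
class TextContent:  # hypothetical stand-in for PromptMessageContent
    def __init__(self, data: str):
        self.data = data

accumulated = ""                     # assistant_message.content begins as str
delta = [TextContent("hello")]       # multimodal chunk delta: a list, not a str

try:
    accumulated += delta             # the old code path
except TypeError as exc:
    print(exc)                       # can only concatenate str (not "list") to str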

api/core/model_runtime/utils/helper.py  (+17 -0)

 import pydantic
 from pydantic import BaseModel
+
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes


 def dump_model(model: BaseModel) -> dict:
     if hasattr(pydantic, "model_dump"):
         return pydantic.model_dump(model)  # type: ignore
     else:
         return model.model_dump()
+
+
+def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
+    if content is None:
+        message_text = ""
+    elif isinstance(content, str):
+        message_text = content
+    elif isinstance(content, list):
+        # Assuming the list contains PromptMessageContent objects with a "data" attribute
+        message_text = "".join(
+            item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
+        )
+    else:
+        message_text = str(content)
+    return message_text
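
As a quick sanity check, the helper's branches can be exercised standalone. The sketch below restates the same logic without the Dify-specific type hints, with a hypothetical `Stub` standing in for a `PromptMessageContentUnionTypes` member:

# Standalone sketch; Stub is a hypothetical stand-in for PromptMessageContent.
class Stub:
    def __init__(self, data):
        self.data = data

    def __str__(self):
        return f"<stub {self.data!r}>"

def convert_llm_result_chunk_to_str(content):
    # Same branching as the helper above, minus the type annotations.
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
        )
    return str(content)

assert convert_llm_result_chunk_to_str(None) == ""
assert convert_llm_result_chunk_to_str("plain") == "plain"
assert convert_llm_result_chunk_to_str([Stub("a"), Stub("b")]) == "ab"
assert convert_llm_result_chunk_to_str([Stub(42)]) == "<stub 42>"  # non-str data falls back to str(item)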

api/core/workflow/nodes/llm/node.py  (+3 -13)

 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.entities.plugin import ModelProviderID
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ ... @@
     def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
+            message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)

             yield ModelInvokeCompletedEvent(
                 text=message_text,
@@ ... @@
         usage = None
         finish_reason = None
         for result in invoke_result:
-            text = result.delta.message.content
+            text = convert_llm_result_chunk_to_str(result.delta.message.content)
             full_text += text

             yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
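
With both call sites routed through the helper, every delta is normalized to `str` before concatenation, so mixed streams can no longer break the accumulator. A condensed sketch of the resulting pattern (`Part` and `convert` are simplified stand-ins for the real entities and helper):

class Part:  # simplified stand-in for a PromptMessageContent item
    def __init__(self, data):
        self.data = data

def convert(content):  # abbreviated version of convert_llm_result_chunk_to_str
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    return "".join(p.data if isinstance(getattr(p, "data", None), str) else str(p) for p in content)

# Mixed str / list / None deltas, as a multimodal provider might stream them:
deltas = ["Hello, ", [Part("wor"), Part("ld")], None, "!"]

full_text = ""
for delta in deltas:
    full_text += convert(delta)  # always str + str, no TypeError

print(full_text)  # Hello, world!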
