@@ -3,7 +3,7 @@ from binascii import hexlify, unhexlify
 from collections.abc import Generator
 
 from core.model_manager import ModelManager
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
     SystemPromptMessage,
@@ -46,7 +46,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             model_parameters=payload.completion_params,
             tools=payload.tools,
             stop=payload.stop,
-            stream=payload.stream or True,
+            stream=True if payload.stream is None else payload.stream,
             user=user_id,
         )
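
A note on the hunk above: `payload.stream or True` always evaluates to `True`, because `or` falls through to the right operand whenever `payload.stream` is falsy, so a caller's explicit `stream=False` was silently ignored. The conditional expression only applies the `True` default when `payload.stream` is `None`. A minimal standalone sketch of the difference (plain Python, with `stream` standing in for `payload.stream`):

```python
# Compare both expressions for every value payload.stream can take
# (True, False, or unset/None).
for stream in (True, False, None):
    old = stream or True                      # False or True -> True: explicit False is lost
    new = True if stream is None else stream  # only None falls back to the default
    print(f"stream={stream!r}: old={old}, new={new}")

# Output:
# stream=True: old=True, new=True
# stream=False: old=True, new=False   <- the behavior this diff fixes
# stream=None: old=True, new=True
```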
@@ -64,7 +64,21 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         else:
             if response.usage:
                 LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-            return response
+
+            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
+                yield LLMResultChunk(
+                    model=response.model,
+                    prompt_messages=response.prompt_messages,
+                    system_fingerprint=response.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=response.message,
+                        usage=response.usage,
+                        finish_reason="",
+                    ),
+                )
+
+            return handle_non_streaming(response)
 
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
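
For context on the last hunk: instead of returning the raw `LLMResult` from the non-streaming branch, the method now wraps it in a one-chunk generator, so callers consume the same `Generator[LLMResultChunk, ...]` interface whether or not streaming was requested. A hedged sketch of the pattern, using simplified stand-in types rather than Dify's real entities:

```python
from collections.abc import Generator
from dataclasses import dataclass

# Simplified stand-ins for LLMResultChunkDelta / LLMResultChunk; the real
# Dify entities carry more fields (usage, system_fingerprint, ...).
@dataclass
class Delta:
    index: int
    message: str

@dataclass
class Chunk:
    model: str
    delta: Delta

def wrap_non_streaming(model: str, full_text: str) -> Generator[Chunk, None, None]:
    # Mirror the diff: emit the complete, already-finished result
    # as a single chunk at index 0.
    yield Chunk(model=model, delta=Delta(index=0, message=full_text))

# Whether or not the backend actually streamed, the caller just iterates:
for chunk in wrap_non_streaming("some-model", "Hello, world"):
    print(chunk.delta.message)  # -> Hello, world
```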