 from collections.abc import Generator
 from core.model_manager import ModelManager
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
     SystemPromptMessage,
             model_parameters=payload.completion_params,
             tools=payload.tools,
             stop=payload.stop,
-            stream=payload.stream or True,
+            stream=True if payload.stream is None else payload.stream,
             user=user_id,
         )
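The `stream` default is the behavioral fix in this hunk: `payload.stream or True` evaluates to `True` for every falsy value, so a caller passing an explicit `stream=False` was silently upgraded to a streaming invocation; the ternary applies the default only when the field is `None`. A minimal standalone sketch of the difference (`old_default`/`new_default` are illustrative helpers, not part of the codebase):

# Illustrative only: these helpers isolate the two expressions
# from the diff above so the difference can be asserted directly.

def old_default(stream):
    # `or` falls through on ANY falsy value, so False becomes True.
    return stream or True

def new_default(stream):
    # Substitute the default only when no value was supplied at all.
    return True if stream is None else stream

assert old_default(None) is True
assert old_default(False) is True    # bug: explicit opt-out was ignored
assert new_default(None) is True
assert new_default(False) is False   # explicit opt-out now honored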
         else:
             if response.usage:
                 LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-            return response
+
+            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
+                yield LLMResultChunk(
+                    model=response.model,
+                    prompt_messages=response.prompt_messages,
+                    system_fingerprint=response.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=response.message,
+                        usage=response.usage,
+                        finish_reason="",
+                    ),
+                )
+
+            return handle_non_streaming(response)
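The new `handle_non_streaming` generator normalizes the blocking path: instead of returning a bare `LLMResult`, it yields the complete result as a single `LLMResultChunk` (index 0, empty finish reason), so callers can iterate over the response the same way whether or not the model streamed. A hedged sketch of the same pattern in isolation — the `Result`/`Chunk` dataclasses below are stand-ins, not Dify's entities:

# Stand-in types replace LLMResult/LLMResultChunk for a runnable example.
from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class Result:   # stand-in for LLMResult
    model: str
    message: str

@dataclass
class Chunk:    # stand-in for LLMResultChunk
    model: str
    message: str
    index: int

def as_stream(result: Result) -> Generator[Chunk, None, None]:
    # Emit the complete result as one chunk so downstream code can
    # consume streaming and non-streaming responses with the same loop.
    yield Chunk(model=result.model, message=result.message, index=0)

for chunk in as_stream(Result(model="example-model", message="hello")):
    print(chunk.index, chunk.message)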
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):