@@ -2,11 +2,14 @@ import tempfile
 from binascii import hexlify, unhexlify
 from collections.abc import Generator
 
+from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import (
     LLMResult,
     LLMResultChunk,
     LLMResultChunkDelta,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
 )
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
@@ -16,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
 from core.plugin.entities.request import (
     RequestInvokeLLM,
+    RequestInvokeLLMWithStructuredOutput,
     RequestInvokeModeration,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
@@ -85,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
 
         return handle_non_streaming(response)
 
+    @classmethod
+    def invoke_llm_with_structured_output(
+        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
+    ):
+        """
+        Invoke an LLM with structured output on behalf of a plugin and deduct the tenant's LLM quota.
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {payload.model}")
+
+        response = invoke_llm_with_structured_output(
+            provider=payload.provider,
+            model_schema=model_schema,
+            model_instance=model_instance,
+            prompt_messages=payload.prompt_messages,
+            json_schema=payload.structured_output_schema,
+            tools=payload.tools,
+            stop=payload.stop,
+            stream=True if payload.stream is None else payload.stream,
+            user=user_id,
+            model_parameters=payload.completion_params,
+        )
+
+        if isinstance(response, Generator):
+
+            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                for chunk in response:
+                    if chunk.delta.usage:
+                        llm_utils.deduct_llm_quota(
+                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+                        )
+                    chunk.prompt_messages = []
+                    yield chunk
+
+            return handle()
+        else:
+            if response.usage:
+                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+
+            def handle_non_streaming(
+                response: LLMResultWithStructuredOutput,
+            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                yield LLMResultChunkWithStructuredOutput(
+                    model=response.model,
+                    prompt_messages=[],
+                    system_fingerprint=response.system_fingerprint,
+                    structured_output=response.structured_output,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=response.message,
+                        usage=response.usage,
+                        finish_reason="",
+                    ),
+                )
+
+            return handle_non_streaming(response)
+
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
         """