import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from models.account import Tenant
class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    """
    Backwards-invocation entry points that let a plugin call models of a tenant.

    Each ``invoke_*`` method looks up a model instance through ``ModelManager``
    using the tenant id and the provider / model type / model name carried by
    the request payload, then forwards the remaining payload fields to the
    matching method of that model instance.
    """

    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        Invoke an LLM and expose its output as a generator of chunks.

        A streaming response is relayed chunk by chunk; a blocking ``LLMResult``
        is wrapped into a generator that yields exactly one chunk. Whenever
        usage information is present, ``LLMNode.deduct_llm_quota`` is called for
        the tenant.

        :param user_id: passed to the model as ``user``
        :param tenant: tenant whose model instance is resolved and charged
        :param payload: model selection, prompt messages, completion parameters,
            tools, stop words and the optional ``stream`` flag
        :return: generator of ``LLMResultChunk``
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model; streaming is the default when the payload does not say otherwise
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
            stream=True if payload.stream is None else payload.stream,
            user=user_id,
        )

        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                # quota is deducted lazily, as the consumer pulls chunks that carry usage
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            # blocking result: deduct right away, before the wrapper generator is consumed
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                # repackage the complete result as a single chunk
                yield LLMResultChunk(
                    model=response.model,
                    prompt_messages=response.prompt_messages,
                    system_fingerprint=response.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=response.message,
                        usage=response.usage,
                        finish_reason="",
                    ),
                )

            return handle_non_streaming(response)

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        Invoke a text embedding model on ``payload.texts``.

        :param user_id: passed to the model as ``user``
        :param tenant: tenant whose model instance is resolved
        :param payload: model selection and the texts to embed
        :return: whatever the model instance's ``invoke_text_embedding`` returns
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        Invoke a rerank model for ``payload.query`` against ``payload.docs``.

        :param user_id: passed to the model as ``user``
        :param tenant: tenant whose model instance is resolved
        :param payload: model selection, query, docs, ``score_threshold`` and ``top_n``
        :return: whatever the model instance's ``invoke_rerank`` returns
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        Invoke a text-to-speech model and stream the result as hex strings.

        :param user_id: passed to the model as ``user``
        :param tenant: tenant whose model instance is resolved
        :param payload: model selection, ``content_text`` and ``voice``
        :return: generator of ``{"result": <hex string>}`` dicts, one per chunk
            produced by the model
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )

        def handle() -> Generator[dict, None, None]:
            # each chunk is hex-encoded and decoded to str before being yielded
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        Invoke a speech-to-text model on the hex-encoded audio in ``payload.file``.

        The decoded bytes are written to a temporary ``.mp3`` file that is handed
        to the model and removed when the ``with`` block exits.

        :param user_id: passed to the model as ``user``
        :param tenant: tenant whose model instance is resolved
        :param payload: model selection and the hex-encoded file content
        :return: ``{"result": <model response>}``
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            # rewind so the file object is positioned at the start when passed on
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        Invoke a moderation model on ``payload.text``.

        :param user_id: passed to the model as ``user``
        :param tenant: tenant whose model instance is resolved
        :param payload: model selection and the text to check
        :return: ``{"result": <model response>}``
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
        """
        Return ``ModelInvocationUtils.get_max_llm_context_tokens`` for the tenant.

        :param tenant_id: id of the tenant to query
        :return: the value reported by ``ModelInvocationUtils``
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        Return ``ModelInvocationUtils.calculate_tokens`` for the given messages.

        :param tenant_id: id of the tenant to query
        :param prompt_messages: messages whose tokens are counted
        :return: the token count reported by ``ModelInvocationUtils``
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)

    @classmethod
    def invoke_system_model(
        cls,
        user_id: str,
        tenant: Tenant,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        Invoke a model through ``ModelInvocationUtils.invoke`` on behalf of a plugin.

        The call is tagged with ``ToolProviderType.PLUGIN`` and the tool name
        ``"plugin"``.

        :param user_id: id of the user the call is made for
        :param tenant: tenant whose id is passed to ``ModelInvocationUtils``
        :param prompt_messages: messages to send
        :return: the ``LLMResult`` returned by ``ModelInvocationUtils.invoke``
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=tenant.id,
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
        )

    @classmethod
    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary) -> str:
        """
        Summarize ``payload.text`` so that it fits the system model.

        Text whose token count is below 60% of the system model's max tokens is
        returned unchanged. Otherwise the text is split on newlines (over-long
        lines are cut into fixed-size pieces), the pieces are merged into
        messages, each message is summarized with ``invoke_system_model``, and
        the summaries are joined with newlines. If the joined summaries still
        exceed 70% of the max tokens, the method calls itself on them.

        :param user_id: id of the user the model calls are made for
        :param tenant: tenant whose system model is used
        :param payload: the text to summarize and an extra instruction that is
            substituted into the system prompt
        :return: the original text or its summary
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        content = payload.text

        SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
retain the original meaning and keep the key points.
however, the text you got is too long, what you got is possible a part of the text.
Please summarize the text you got.
Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""

        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
            )
            < max_tokens * 0.6
        ):
            return content

        # Built once, and only after the early return above, so a short text
        # never touches payload.instruction.
        system_prompt = SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)
        token_budget = max_tokens * 0.7  # upper bound, in tokens, for one request
        char_budget = max_tokens * 0.5  # cheap length pre-check, in characters
        # Never slice by 0 characters: that would leave the line unchanged and
        # the splitting loop below would not terminate.
        chunk_size = max(1, int(char_budget))

        def count_tokens(text: str) -> int:
            # tokens of the full request: system prompt plus the candidate text
            return cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[
                    SystemPromptMessage(content=system_prompt),
                    UserPromptMessage(content=text),
                ],
            )

        def summarize(text: str) -> str:
            summary = cls.invoke_system_model(
                user_id=user_id,
                tenant=tenant,
                prompt_messages=[
                    SystemPromptMessage(content=system_prompt),
                    UserPromptMessage(content=text),
                ],
            )

            assert isinstance(summary.message.content, str)
            return summary.message.content

        # split long line into multiple lines
        new_lines: list[str] = []
        for line in content.split("\n"):
            if not line.strip():
                continue
            if len(line) >= char_budget:
                # Cut fixed-size pieces off the front while the line is over
                # budget. Stop once the remainder fits in a single piece: slicing
                # further could only produce empty strings, and would loop
                # forever if the system prompt alone exceeds the budget.
                while len(line) > chunk_size and count_tokens(line) > token_budget:
                    new_lines.append(line[:chunk_size])
                    line = line[chunk_size:]
            new_lines.append(line)

        # merge lines into messages with max tokens
        # NOTE(review): pieces are concatenated without a separator, so the
        # original line breaks are not preserved inside a message.
        messages: list[str] = []
        for piece in new_lines:
            if not messages:
                messages.append(piece)
            elif len(messages[-1]) + len(piece) < char_budget:
                # clearly short enough; skip the token count
                messages[-1] += piece
            elif count_tokens(messages[-1] + piece) > token_budget:
                messages.append(piece)
            else:
                messages[-1] += piece

        summaries = [summarize(message) for message in messages]
        result = "\n".join(summaries)

        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=result)],
            )
            > max_tokens * 0.7
        ):
            # still too long: summarize the summaries
            return cls.invoke_summary(
                user_id=user_id,
                tenant=tenant,
                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
            )

        return result
 |